- import argparse
- import os
- import warnings
- from typing import TYPE_CHECKING, Optional, Tuple, Union
- import numpy as np
- import torch
- import tqdm
- from .audio import (
- FRAMES_PER_SECOND,
- HOP_LENGTH,
- N_FRAMES,
- N_SAMPLES,
- SAMPLE_RATE,
- log_mel_spectrogram,
- pad_or_trim,
- )
- from .decoding import DecodingOptions, DecodingResult
- from .timing import add_word_timestamps
- from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
- from .utils import (
- exact_div,
- format_timestamp,
- get_writer,
- make_safe,
- optional_float,
- optional_int,
- str2bool,
- )
- if TYPE_CHECKING:
- from .model import Whisper
- def transcribe(
- model: "Whisper",
- audio: Union[str, np.ndarray, torch.Tensor],
- *,
- verbose: Optional[bool] = None,
- temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
- compression_ratio_threshold: Optional[float] = 2.4,
- logprob_threshold: Optional[float] = -1.0,
- no_speech_threshold: Optional[float] = 0.6,
- condition_on_previous_text: bool = True,
- initial_prompt: Optional[str] = None,
- word_timestamps: bool = False,
- prepend_punctuations: str = "\"'“¿([{-",
- append_punctuations: str = "\"'.。,,!!??::”)]}、",
- **decode_options,
- ):
- """
- Transcribe an audio file using Whisper
- Parameters
- ----------
- model: Whisper
- The Whisper model instance
- audio: Union[str, np.ndarray, torch.Tensor]
- The path to the audio file to open, or the audio waveform
- verbose: bool
- Whether to display the text being decoded to the console. If True, displays all the details;
- if False, displays minimal details; if None, does not display anything.
- temperature: Union[float, Tuple[float, ...]]
- Temperature for sampling. It can be a tuple of temperatures, which will be successively used
- upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
- compression_ratio_threshold: float
- If the gzip compression ratio is above this value, treat as failed
- logprob_threshold: float
- If the average log probability over sampled tokens is below this value, treat as failed
- no_speech_threshold: float
- If the no_speech probability is higher than this value AND the average log probability
- over sampled tokens is below `logprob_threshold`, consider the segment as silent
- condition_on_previous_text: bool
- If True, the previous output of the model is provided as a prompt for the next window;
- disabling may make the text inconsistent across windows, but the model becomes less prone to
- getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
- word_timestamps: bool
- Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
- and include the timestamps for each word in each segment.
- prepend_punctuations: str
- If word_timestamps is True, merge these punctuation symbols with the next word
- append_punctuations: str
- If word_timestamps is True, merge these punctuation symbols with the previous word
- initial_prompt: Optional[str]
- Optional text to provide as a prompt for the first window. This can be used to provide, or
- "prompt-engineer", a context for transcription, e.g. custom vocabularies or proper nouns,
- to make it more likely to predict those words correctly.
- decode_options: dict
- Keyword arguments to construct `DecodingOptions` instances
- Returns
- -------
- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
- the spoken language ("language"), which is detected when `decode_options["language"]` is None.
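- Examples
- --------
- A minimal usage sketch; the model name and audio path below are illustrative:
- >>> import whisper
- >>> model = whisper.load_model("base")
- >>> result = transcribe(model, "audio.mp3")
- >>> print(result["text"])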
- """
- dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
- if model.device == torch.device("cpu"):
- if torch.cuda.is_available():
- warnings.warn("Performing inference on CPU when CUDA is available")
- if dtype == torch.float16:
- warnings.warn("FP16 is not supported on CPU; using FP32 instead")
- dtype = torch.float32
- if dtype == torch.float32:
- decode_options["fp16"] = False
- # Pad 30 seconds of silence to the input audio, for slicing
- mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
- content_frames = mel.shape[-1] - N_FRAMES
- if decode_options.get("language", None) is None:
- if not model.is_multilingual:
- decode_options["language"] = "en"
- else:
- if verbose:
- print(
- "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
- )
- mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
- _, probs = model.detect_language(mel_segment)
- decode_options["language"] = max(probs, key=probs.get)
- if verbose is not None:
- print(
- f"Detected language: {LANGUAGES[decode_options['language']].title()}"
- )
- language: str = decode_options["language"]
- task: str = decode_options.get("task", "transcribe")
- tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
- if word_timestamps and task == "translate":
- warnings.warn("Word-level timestamps on translations may not be reliable.")
- def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
- temperatures = (
- [temperature] if isinstance(temperature, (int, float)) else temperature
- )
- decode_result = None
- for t in temperatures:
- kwargs = {**decode_options}
- if t > 0:
- # disable beam_size and patience when t > 0
- kwargs.pop("beam_size", None)
- kwargs.pop("patience", None)
- else:
- # disable best_of when t == 0
- kwargs.pop("best_of", None)
- options = DecodingOptions(**kwargs, temperature=t)
- decode_result = model.decode(segment, options)
- needs_fallback = False
- if (
- compression_ratio_threshold is not None
- and decode_result.compression_ratio > compression_ratio_threshold
- ):
- needs_fallback = True # too repetitive
- if (
- logprob_threshold is not None
- and decode_result.avg_logprob < logprob_threshold
- ):
- needs_fallback = True # average log probability is too low
- if (
- no_speech_threshold is not None
- and decode_result.no_speech_prob > no_speech_threshold
- ):
- needs_fallback = False # silence
- if not needs_fallback:
- break
- return decode_result
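- # `seek` is the current position, in mel frames, within the padded spectrogram;
- # each iteration of the loop below advances it by up to one 30-second window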
- seek = 0
- input_stride = exact_div(
- N_FRAMES, model.dims.n_audio_ctx
- ) # mel frames per output token: 2
- time_precision = (
- input_stride * HOP_LENGTH / SAMPLE_RATE
- ) # time per output token: 0.02 (seconds)
- all_tokens = []
- all_segments = []
- prompt_reset_since = 0
- if initial_prompt is not None:
- initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
- all_tokens.extend(initial_prompt_tokens)
- else:
- initial_prompt_tokens = []
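- # helper that packages a span of decoded tokens into a segment dict,
- # including its timing and the per-segment quality metrics from decoding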
- def new_segment(
- *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
- ):
- tokens = tokens.tolist()
- text_tokens = [token for token in tokens if token < tokenizer.eot]
- return {
- "seek": seek,
- "start": start,
- "end": end,
- "text": tokenizer.decode(text_tokens),
- "tokens": tokens,
- "temperature": result.temperature,
- "avg_logprob": result.avg_logprob,
- "compression_ratio": result.compression_ratio,
- "no_speech_prob": result.no_speech_prob,
- }
- # show the progress bar when verbose is False (if True, transcribed text will be printed)
- with tqdm.tqdm(
- total=content_frames, unit="frames", disable=verbose is not False
- ) as pbar:
- last_speech_timestamp = 0.0
- while seek < content_frames:
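- # slice the next (up to) 30-second window out of the padded mel spectrogram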
- time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
- mel_segment = mel[:, seek : seek + N_FRAMES]
- segment_size = min(N_FRAMES, content_frames - seek)
- segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
- mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
- decode_options["prompt"] = all_tokens[prompt_reset_since:]
- result: DecodingResult = decode_with_fallback(mel_segment)
- tokens = torch.tensor(result.tokens)
- if no_speech_threshold is not None:
- # no-speech check: skip this window if it is likely to contain no speech
- should_skip = result.no_speech_prob > no_speech_threshold
- if (
- logprob_threshold is not None
- and result.avg_logprob > logprob_threshold
- ):
- # don't skip if the logprob is high enough, despite the no_speech_prob
- should_skip = False
- if should_skip:
- seek += segment_size # fast-forward to the next segment boundary
- continue
- previous_seek = seek
- current_segments = []
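- # timestamp tokens delimit segments within the decoded window; two adjacent
- # timestamp tokens mark a boundary between one segment and the next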
- timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
- single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
- consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
- consecutive.add_(1)
- if len(consecutive) > 0:
- # if the output contains two consecutive timestamp tokens
- slices = consecutive.tolist()
- if single_timestamp_ending:
- slices.append(len(tokens))
- last_slice = 0
- for current_slice in slices:
- sliced_tokens = tokens[last_slice:current_slice]
- start_timestamp_pos = (
- sliced_tokens[0].item() - tokenizer.timestamp_begin
- )
- end_timestamp_pos = (
- sliced_tokens[-1].item() - tokenizer.timestamp_begin
- )
- current_segments.append(
- new_segment(
- start=time_offset + start_timestamp_pos * time_precision,
- end=time_offset + end_timestamp_pos * time_precision,
- tokens=sliced_tokens,
- result=result,
- )
- )
- last_slice = current_slice
- if single_timestamp_ending:
- # single timestamp at the end means no speech after the last timestamp.
- seek += segment_size
- else:
- # otherwise, ignore the unfinished segment and seek to the last timestamp
- last_timestamp_pos = (
- tokens[last_slice - 1].item() - tokenizer.timestamp_begin
- )
- seek += last_timestamp_pos * input_stride
- else:
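- # no consecutive timestamps: emit the whole window as a single segment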
- duration = segment_duration
- timestamps = tokens[timestamp_tokens.nonzero().flatten()]
- if (
- len(timestamps) > 0
- and timestamps[-1].item() != tokenizer.timestamp_begin
- ):
- # no consecutive timestamps but it has a timestamp; use the last one.
- last_timestamp_pos = (
- timestamps[-1].item() - tokenizer.timestamp_begin
- )
- duration = last_timestamp_pos * time_precision
- current_segments.append(
- new_segment(
- start=time_offset,
- end=time_offset + duration,
- tokens=tokens,
- result=result,
- )
- )
- seek += segment_size
- if word_timestamps:
- add_word_timestamps(
- segments=current_segments,
- model=model,
- tokenizer=tokenizer,
- mel=mel_segment,
- num_frames=segment_size,
- prepend_punctuations=prepend_punctuations,
- append_punctuations=append_punctuations,
- last_speech_timestamp=last_speech_timestamp,
- )
- word_end_timestamps = [
- w["end"] for s in current_segments for w in s["words"]
- ]
- if len(word_end_timestamps) > 0:
- last_speech_timestamp = word_end_timestamps[-1]
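- # re-anchor the next window at the end of the last timed word rather than the raw segment boundary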
- if not single_timestamp_ending and len(word_end_timestamps) > 0:
- seek_shift = round(
- (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
- )
- if seek_shift > 0:
- seek = previous_seek + seek_shift
- if verbose:
- for segment in current_segments:
- start, end, text = segment["start"], segment["end"], segment["text"]
- line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
- print(make_safe(line))
- # if a segment is instantaneous or does not contain text, clear it
- for i, segment in enumerate(current_segments):
- if segment["start"] == segment["end"] or segment["text"].strip() == "":
- segment["text"] = ""
- segment["tokens"] = []
- segment["words"] = []
- all_segments.extend(
- [
- {"id": i, **segment}
- for i, segment in enumerate(
- current_segments, start=len(all_segments)
- )
- ]
- )
- all_tokens.extend(
- [token for segment in current_segments for token in segment["tokens"]]
- )
- if not condition_on_previous_text or result.temperature > 0.5:
- # do not feed the prompt tokens if a high temperature was used
- prompt_reset_since = len(all_tokens)
- # update progress bar
- pbar.update(min(content_frames, seek) - previous_seek)
- return dict(
- text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
- segments=all_segments,
- language=language,
- )
- def cli():
- from . import available_models
- # fmt: off
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
- parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
- parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
- parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
- parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
- parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
- parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
- parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
- parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
- parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
- parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
- parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
- parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
- parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
- parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
- parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
- parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
- parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
- parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
- parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
- parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
- parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
- parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
- parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
- parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
- parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
- parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
- parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
- parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
- # fmt: on
- args = parser.parse_args().__dict__
- model_name: str = args.pop("model")
- model_dir: str = args.pop("model_dir")
- output_dir: str = args.pop("output_dir")
- output_format: str = args.pop("output_format")
- device: str = args.pop("device")
- os.makedirs(output_dir, exist_ok=True)
- if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
- if args["language"] is not None:
- warnings.warn(
- f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
- )
- args["language"] = "en"
- temperature = args.pop("temperature")
- if (increment := args.pop("temperature_increment_on_fallback")) is not None:
- temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
- else:
- temperature = [temperature]
- if (threads := args.pop("threads")) > 0:
- torch.set_num_threads(threads)
- from . import load_model
- model = load_model(model_name, device=device, download_root=model_dir)
- writer = get_writer(output_format, output_dir)
- word_options = ["highlight_words", "max_line_count", "max_line_width"]
- if not args["word_timestamps"]:
- for option in word_options:
- if args[option]:
- parser.error(f"--{option} requires --word_timestamps True")
- if args["max_line_count"] and not args["max_line_width"]:
- warnings.warn("--max_line_count has no effect without --max_line_width")
- writer_args = {arg: args.pop(arg) for arg in word_options}
- for audio_path in args.pop("audio"):
- result = transcribe(model, audio_path, temperature=temperature, **args)
- writer(result, audio_path, writer_args)
- if __name__ == "__main__":
- cli()