# transcribe.py

import argparse
import os
import traceback
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable

import numpy as np
import torch
import tqdm

from .audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    log_mel_spectrogram,
    pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
    exact_div,
    format_timestamp,
    get_writer,
    make_safe,
    optional_float,
    optional_int,
    str2bool,
)

if TYPE_CHECKING:
    from .model import Whisper


def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    initial_prompt: Optional[str] = None,
    word_timestamps: bool = False,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    progress_callback: Optional[Callable[[float], None]] = None,
    **decode_options,
):
  49. """
  50. Transcribe an audio file using Whisper
  51. Parameters
  52. ----------
  53. model: Whisper
  54. The Whisper model instance
  55. audio: Union[str, np.ndarray, torch.Tensor]
  56. The path to the audio file to open, or the audio waveform
  57. verbose: bool
  58. Whether to display the text being decoded to the console. If True, displays all the details,
  59. If False, displays minimal details. If None, does not display anything
  60. temperature: Union[float, Tuple[float, ...]]
  61. Temperature for sampling. It can be a tuple of temperatures, which will be successively used
  62. upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
  63. compression_ratio_threshold: float
  64. If the gzip compression ratio is above this value, treat as failed
  65. logprob_threshold: float
  66. If the average log probability over sampled tokens is below this value, treat as failed
  67. no_speech_threshold: float
  68. If the no_speech probability is higher than this value AND the average log probability
  69. over sampled tokens is below `logprob_threshold`, consider the segment as silent
  70. condition_on_previous_text: bool
  71. if True, the previous output of the model is provided as a prompt for the next window;
  72. disabling may make the text inconsistent across windows, but the model becomes less prone to
  73. getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
  74. word_timestamps: bool
  75. Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
  76. and include the timestamps for each word in each segment.
  77. prepend_punctuations: str
  78. If word_timestamps is True, merge these punctuation symbols with the next word
  79. append_punctuations: str
  80. If word_timestamps is True, merge these punctuation symbols with the previous word
  81. initial_prompt: Optional[str]
  82. Optional text to provide as a prompt for the first window. This can be used to provide, or
  83. "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
  84. to make it more likely to predict those word correctly.
  85. decode_options: dict
  86. Keyword arguments to construct `DecodingOptions` instances
  87. Returns
  88. -------
  89. A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
  90. the spoken language ("language"), which is detected when `decode_options["language"]` is None.
  91. """
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    # Pad 30-seconds of silence to the input audio, for slicing
    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
    content_frames = mel.shape[-1] - N_FRAMES

    if decode_options.get("language", None) is None:
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
            _, probs = model.detect_language(mel_segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language=language,
        task=task,
    )

    if word_timestamps and task == "translate":
        warnings.warn("Word-level timestamps on translations may not be reliable.")

    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options)

            needs_fallback = False
            if (
                compression_ratio_threshold is not None
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low
            if (
                no_speech_threshold is not None
                and decode_result.no_speech_prob > no_speech_threshold
            ):
                needs_fallback = False  # silence
            if not needs_fallback:
                break

        return decode_result

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    if initial_prompt is not None:
        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt_tokens)
    else:
        initial_prompt_tokens = []

    def new_segment(
        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        tokens = tokens.tolist()
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
            "seek": seek,
            "start": start,
            "end": end,
            "text": tokenizer.decode(text_tokens),
            "tokens": tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }

    # show the progress bar when verbose is False (if True, transcribed text will be printed)
    with tqdm.tqdm(
        total=content_frames, unit="frames", disable=verbose is not False
    ) as pbar:
        last_speech_timestamp = 0.0
        while seek < content_frames:
            if progress_callback is not None:
                progress_value = seek / content_frames
                progress_callback(progress_value)
            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            mel_segment = mel[:, seek: seek + N_FRAMES]
            segment_size = min(N_FRAMES, content_frames - seek)
            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(mel_segment)
            tokens = torch.tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment_size  # fast-forward to the next segment boundary
                    continue

            previous_seek = seek
            current_segments = []

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            consecutive.add_(1)
            if len(consecutive) > 0:
                # if the output contains two consecutive timestamp tokens
                slices = consecutive.tolist()
                if single_timestamp_ending:
                    slices.append(len(tokens))

                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_pos = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_pos = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    current_segments.append(
                        new_segment(
                            start=time_offset + start_timestamp_pos * time_precision,
                            end=time_offset + end_timestamp_pos * time_precision,
                            tokens=sliced_tokens,
                            result=result,
                        )
                    )
                    last_slice = current_slice

                if single_timestamp_ending:
                    # single timestamp at the end means no speech after the last timestamp.
                    seek += segment_size
                else:
                    # otherwise, ignore the unfinished segment and seek to the last timestamp
                    last_timestamp_pos = (
                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                    )
                    seek += last_timestamp_pos * input_stride
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    last_timestamp_pos = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_pos * time_precision

                current_segments.append(
                    new_segment(
                        start=time_offset,
                        end=time_offset + duration,
                        tokens=tokens,
                        result=result,
                    )
                )
                seek += segment_size

            if word_timestamps:
                add_word_timestamps(
                    segments=current_segments,
                    model=model,
                    tokenizer=tokenizer,
                    mel=mel_segment,
                    num_frames=segment_size,
                    prepend_punctuations=prepend_punctuations,
                    append_punctuations=append_punctuations,
                    last_speech_timestamp=last_speech_timestamp,
                )
                word_end_timestamps = [
                    w["end"] for s in current_segments for w in s["words"]
                ]
                if len(word_end_timestamps) > 0:
                    last_speech_timestamp = word_end_timestamps[-1]
                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                    seek_shift = round(
                        (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                    )
                    if seek_shift > 0:
                        seek = previous_seek + seek_shift

            if verbose:
                for segment in current_segments:
                    start, end, text = segment["start"], segment["end"], segment["text"]
                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
                    print(make_safe(line))

            # if a segment is instantaneous or does not contain text, clear it
            for i, segment in enumerate(current_segments):
                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                    segment["text"] = ""
                    segment["tokens"] = []
                    segment["words"] = []

            all_segments.extend(
                [
                    {"id": i, **segment}
                    for i, segment in enumerate(
                        current_segments, start=len(all_segments)
                    )
                ]
            )
            all_tokens.extend(
                [token for segment in current_segments for token in segment["tokens"]]
            )

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update progress bar
            pbar.update(min(content_frames, seek) - previous_seek)

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
        segments=all_segments,
        language=language,
    )
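

# Example usage (a minimal sketch, kept as a comment so the module itself is unchanged).
# The package name `whisper`, the model name "small", and the file "audio.mp3" are
# assumptions for illustration; `load_model` and `transcribe` are the functions defined
# in this package.
#
#     from whisper import load_model
#     from whisper.transcribe import transcribe
#
#     model = load_model("small")
#     result = transcribe(
#         model,
#         "audio.mp3",
#         # progress_callback receives the fraction of audio processed so far (0.0-1.0)
#         progress_callback=lambda fraction: print(f"{fraction:.0%} processed"),
#     )
#     print(result["text"])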


def cli():
    from . import available_models

    def valid_model_name(name):
        if name in available_models() or os.path.exists(name):
            return name
        raise ValueError(
            f"model should be one of {available_models()} or path to a model checkpoint"
        )

    # fmt: off
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None,
                        help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu",
                        help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_format", "-f", type=str, default="all",
                        choices=["txt", "vtt", "srt", "tsv", "json", "all"],
                        help="format of the output file; if not specified, all available formats will be produced")
    parser.add_argument("--verbose", type=str2bool, default=True,
                        help="whether to print out the progress and debug messages")
    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"],
                        help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None,
                        choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
                        help="language spoken in the audio, specify None to perform language detection")
    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5,
                        help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5,
                        help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None,
                        help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None,
                        help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
    parser.add_argument("--suppress_tokens", type=str, default="-1",
                        help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None,
                        help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True,
                        help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True,
                        help="whether to perform inference in fp16; True by default")
    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2,
                        help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4,
                        help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0,
                        help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6,
                        help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--word_timestamps", type=str2bool, default=False,
                        help="(experimental) extract word-level timestamps and refine the results based on them")
    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-",
                        help="if word_timestamps is True, merge these punctuation symbols with the next word")
    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、",
                        help="if word_timestamps is True, merge these punctuation symbols with the previous word")
    parser.add_argument("--highlight_words", type=str2bool, default=False,
                        help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
    parser.add_argument("--max_line_width", type=optional_int, default=None,
                        help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
    parser.add_argument("--max_line_count", type=optional_int, default=None,
                        help="(requires --word_timestamps True) the maximum number of lines in a segment")
    parser.add_argument("--max_words_per_line", type=optional_int, default=None,
                        help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
  403. parser.add_argument("--threads", type=optional_int, default=0,
  404. help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
  405. # fmt: on

    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    device: str = args.pop("device")
    os.makedirs(output_dir, exist_ok=True)

    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(
                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
            )
        args["language"] = "en"

    temperature = args.pop("temperature")
    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)

    from . import load_model

    model = load_model(model_name, device=device, download_root=model_dir)

    writer = get_writer(output_format, output_dir)
    word_options = [
        "highlight_words",
        "max_line_count",
        "max_line_width",
        "max_words_per_line",
    ]
    if not args["word_timestamps"]:
        for option in word_options:
            if args[option]:
                parser.error(f"--{option} requires --word_timestamps True")
    if args["max_line_count"] and not args["max_line_width"]:
        warnings.warn("--max_line_count has no effect without --max_line_width")
    if args["max_words_per_line"] and args["max_line_width"]:
        warnings.warn("--max_words_per_line has no effect with --max_line_width")

    writer_args = {arg: args.pop(arg) for arg in word_options}
    for audio_path in args.pop("audio"):
        try:
            result = transcribe(model, audio_path, temperature=temperature, **args)
            writer(result, audio_path, **writer_args)
        except Exception as e:
            traceback.print_exc()
            print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")


if __name__ == "__main__":
    cli()
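

# Example CLI invocation (a sketch; the `whisper.transcribe` module path and file name
# are assumptions that depend on how this package is installed):
#
#     python -m whisper.transcribe audio.mp3 --model small --output_format srt \
#         --word_timestamps True --highlight_words True
#
# --highlight_words underlines each word as it is spoken in the srt/vtt output and
# requires --word_timestamps True, as enforced in cli() above.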