decoding.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709
  1. from dataclasses import dataclass, field
  2. from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
  3. import numpy as np
  4. import torch
  5. import torch.nn.functional as F
  6. from torch import Tensor
  7. from torch.distributions import Categorical
  8. from .audio import CHUNK_LENGTH
  9. from .tokenizer import Tokenizer, get_tokenizer
  10. from .utils import compression_ratio
  11. if TYPE_CHECKING:
  12. from .model import Whisper
  13. @torch.no_grad()
  14. def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
  15. """
  16. Detect the spoken language in the audio, and return them as list of strings, along with the ids
  17. of the most probable language tokens and the probability distribution over all language tokens.
  18. This is performed outside the main decode loop in order to not interfere with kv-caching.
  19. Returns
  20. -------
  21. language_tokens : Tensor, shape = (n_audio,)
  22. ids of the most probable language tokens, which appears after the startoftranscript token.
  23. language_probs : List[Dict[str, float]], length = n_audio
  24. list of dictionaries containing the probability distribution over all languages.
  25. """
  26. if tokenizer is None:
  27. tokenizer = get_tokenizer(model.is_multilingual)
  28. if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
  29. raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
  30. single = mel.ndim == 2
  31. if single:
  32. mel = mel.unsqueeze(0)
  33. # skip encoder forward pass if already-encoded audio features were given
  34. if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
  35. mel = model.encoder(mel)
  36. # forward pass using a single token, startoftranscript
  37. n_audio = mel.shape[0]
  38. x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
  39. logits = model.logits(x, mel)[:, 0]
  40. # collect detected languages; suppress all non-language tokens
  41. mask = torch.ones(logits.shape[-1], dtype=torch.bool)
  42. mask[list(tokenizer.all_language_tokens)] = False
  43. logits[:, mask] = -np.inf
  44. language_tokens = logits.argmax(dim=-1)
  45. language_token_probs = logits.softmax(dim=-1).cpu()
  46. language_probs = [
  47. {
  48. c: language_token_probs[i, j].item()
  49. for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
  50. }
  51. for i in range(n_audio)
  52. ]
  53. if single:
  54. language_tokens = language_tokens[0]
  55. language_probs = language_probs[0]
  56. return language_tokens, language_probs
@dataclass(frozen=True)
class DecodingOptions:
    """Immutable collection of the options that control a single decoding task."""

    task: str = "transcribe"  # whether to perform X->X "transcribe" or X->English "translate"
    language: Optional[str] = None  # language that the audio is in; uses detected language if None

    # sampling-related options
    temperature: float = 0.0  # sampling temperature; 0 means the most probable token is taken
    sample_len: Optional[int] = None  # maximum number of tokens to sample
    best_of: Optional[int] = None  # number of independent samples to collect, when t > 0
    beam_size: Optional[int] = None  # number of beams in beam search, when t == 0
    patience: Optional[float] = None  # patience in beam search (https://arxiv.org/abs/2204.05424)

    # options for ranking generations (either beams or best-of-N samples)
    length_penalty: Optional[float] = None  # "alpha" in Google NMT, None defaults to length norm

    # prompt, prefix, and token suppression
    prompt: Optional[Union[str, List[int]]] = None  # text or tokens for the previous context
    prefix: Optional[Union[str, List[int]]] = None  # text or tokens to prefix the current context
    suppress_blank: bool = True  # this will suppress blank outputs

    # list of tokens ids (or comma-separated token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"

    # timestamp sampling options
    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
    max_initial_timestamp: Optional[float] = 1.0  # the initial timestamp cannot be later than this

    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation
@dataclass(frozen=True)
class DecodingResult:
    """Immutable record holding the outcome of decoding one audio input."""

    audio_features: Tensor  # audio features this result was decoded from
    language: str  # language used (given in the options or detected)
    language_probs: Optional[Dict[str, float]] = None  # probability per language code, if collected
    tokens: List[int] = field(default_factory=list)  # ids of the selected token sequence
    text: str = ""  # text decoded from `tokens`
    avg_logprob: float = np.nan  # average log probability of the selected sequence
    no_speech_prob: float = np.nan  # probability of the no-speech token
    temperature: float = np.nan  # temperature the sequence was sampled with
    compression_ratio: float = np.nan  # compression ratio computed from `text`
  92. class Inference:
  93. def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
  94. """Perform a forward pass on the decoder and return per-token logits"""
  95. raise NotImplementedError
  96. def rearrange_kv_cache(self, source_indices) -> None:
  97. """Update the key-value cache according to the updated beams"""
  98. raise NotImplementedError
  99. def cleanup_caching(self) -> None:
  100. """Clean up any resources or hooks after decoding is finished"""
  101. pass
  102. class PyTorchInference(Inference):
  103. def __init__(self, model: "Whisper", initial_token_length: int):
  104. self.model: "Whisper" = model
  105. self.initial_token_length = initial_token_length
  106. self.kv_cache = {}
  107. self.hooks = []
  108. def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
  109. if not self.kv_cache:
  110. self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
  111. if tokens.shape[-1] > self.initial_token_length:
  112. # only need to use the last token except in the first forward pass
  113. tokens = tokens[:, -1:]
  114. return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
  115. def cleanup_caching(self):
  116. for hook in self.hooks:
  117. hook.remove()
  118. self.kv_cache = {}
  119. self.hooks = []
  120. def rearrange_kv_cache(self, source_indices):
  121. for module, tensor in self.kv_cache.items():
  122. # update the key/value cache to contain the selected sequences
  123. self.kv_cache[module] = tensor[source_indices].detach()
  124. class SequenceRanker:
  125. def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
  126. """
  127. Given a list of groups of samples and their cumulative log probabilities,
  128. return the indices of the samples in each group to select as the final result
  129. """
  130. raise NotImplementedError
  131. class MaximumLikelihoodRanker(SequenceRanker):
  132. """
  133. Select the sample with the highest log probabilities, penalized using either
  134. a simple length normalization or Google NMT paper's length penalty
  135. """
  136. def __init__(self, length_penalty: Optional[float]):
  137. self.length_penalty = length_penalty
  138. def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
  139. def scores(logprobs, lengths):
  140. result = []
  141. for logprob, length in zip(logprobs, lengths):
  142. if self.length_penalty is None:
  143. penalty = length
  144. else:
  145. # from the Google NMT paper
  146. penalty = ((5 + length) / 6) ** self.length_penalty
  147. result.append(logprob / penalty)
  148. return result
  149. # get the sequence with the highest score
  150. lengths = [[len(t) for t in s] for s in tokens]
  151. return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
  152. class TokenDecoder:
  153. def reset(self):
  154. """Initialize any stateful variables for decoding a new sequence"""
  155. def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
  156. """Specify how to select the next token, based on the current trace and logits
  157. Parameters
  158. ----------
  159. tokens : Tensor, shape = (n_batch, current_sequence_length)
  160. all tokens in the context so far, including the prefix and sot_sequence tokens
  161. logits : Tensor, shape = (n_batch, vocab_size)
  162. per-token logits of the probability distribution at the current step
  163. sum_logprobs : Tensor, shape = (n_batch)
  164. cumulative log probabilities for each sequence
  165. Returns
  166. -------
  167. tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
  168. the tokens, appended with the selected next token
  169. completed : bool
  170. True if all sequences has reached the end of text
  171. """
  172. raise NotImplementedError
  173. def finalize(
  174. self, tokens: Tensor, sum_logprobs: Tensor
  175. ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
  176. """Finalize search and return the final candidate sequences
  177. Parameters
  178. ----------
  179. tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
  180. all tokens in the context so far, including the prefix and sot_sequence
  181. sum_logprobs : Tensor, shape = (n_audio, n_group)
  182. cumulative log probabilities for each sequence
  183. Returns
  184. -------
  185. tokens : Sequence[Sequence[Tensor]], length = n_audio
  186. sequence of Tensors containing candidate token sequences, for each audio input
  187. sum_logprobs : List[List[float]], length = n_audio
  188. sequence of cumulative log probabilities corresponding to the above
  189. """
  190. raise NotImplementedError
  191. class GreedyDecoder(TokenDecoder):
  192. def __init__(self, temperature: float, eot: int):
  193. self.temperature = temperature
  194. self.eot = eot
  195. def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
  196. if self.temperature == 0:
  197. next_tokens = logits.argmax(dim=-1)
  198. else:
  199. next_tokens = Categorical(logits=logits / self.temperature).sample()
  200. logprobs = F.log_softmax(logits.float(), dim=-1)
  201. current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
  202. sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
  203. next_tokens[tokens[:, -1] == self.eot] = self.eot
  204. tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
  205. completed = (tokens[:, -1] == self.eot).all()
  206. return tokens, completed
  207. def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
  208. # make sure each sequence has at least one EOT token at the end
  209. tokens = F.pad(tokens, (0, 1), value=self.eot)
  210. return tokens, sum_logprobs.tolist()
class BeamSearchDecoder(TokenDecoder):
    """Beam search over token sequences, keeping `beam_size` live beams per audio input.

    Sequences that end with EOT are collected per audio in `finished_sequences`
    (a dict mapping the token tuple to its cumulative log probability) until
    `max_candidates` of them are available for every audio input.
    """

    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        # a patience of None (or any falsy value) falls back to 1.0
        self.patience = patience or 1.0
        # number of finished sequences to collect per audio before decoding stops
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        """Forget the finished sequences collected so far."""
        self.finished_sequences = None

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        """Advance every beam by one token.

        `tokens` is expected to hold `beam_size` consecutive rows per audio input.
        `sum_logprobs` is overwritten in-place with the scores of the kept beams, and
        the kv cache of `self.inference` is rearranged to follow the kept beams.

        Returns the new token tensor and whether every audio input has collected
        `max_candidates` finished sequences.

        Raises ValueError if the number of rows is not a multiple of `beam_size`.
        """
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            # candidate sequence -> score, candidate sequence -> row it came from,
            # and the candidates of this audio that ended with EOT
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                # beam_size + 1 candidates per beam are considered
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    # written in-place at the row the kept sequence will occupy;
                    # this happens only after STEP 1 has read this audio's rows
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio has enough number of samples
        completed = all(
            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        """Return the finished sequences and their scores for each audio input.

        When fewer than `beam_size` sequences have finished for an audio input, the
        unfinished beams are added in order of decreasing cumulative log probability,
        each with an EOT token appended, until `beam_size` sequences are available.
        """
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if len(sequences) < self.beam_size:  # when not enough sequences are finished
                # visit the beams from the highest to the lowest score
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs
  285. class LogitFilter:
  286. def apply(self, logits: Tensor, tokens: Tensor) -> None:
  287. """Apply any filtering or masking to logits in-place
  288. Parameters
  289. ----------
  290. logits : Tensor, shape = (n_batch, vocab_size)
  291. per-token logits of the probability distribution at the current step
  292. tokens : Tensor, shape = (n_batch, current_sequence_length)
  293. all tokens in the context so far, including the prefix and sot_sequence tokens
  294. """
  295. raise NotImplementedError
  296. class SuppressBlank(LogitFilter):
  297. def __init__(self, tokenizer: Tokenizer, sample_begin: int):
  298. self.tokenizer = tokenizer
  299. self.sample_begin = sample_begin
  300. def apply(self, logits: Tensor, tokens: Tensor):
  301. if tokens.shape[1] == self.sample_begin:
  302. logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
  303. class SuppressTokens(LogitFilter):
  304. def __init__(self, suppress_tokens: Sequence[int]):
  305. self.suppress_tokens = list(suppress_tokens)
  306. def apply(self, logits: Tensor, tokens: Tensor):
  307. logits[:, self.suppress_tokens] = -np.inf
  308. class ApplyTimestampRules(LogitFilter):
  309. def __init__(
  310. self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
  311. ):
  312. self.tokenizer = tokenizer
  313. self.sample_begin = sample_begin
  314. self.max_initial_timestamp_index = max_initial_timestamp_index
  315. def apply(self, logits: Tensor, tokens: Tensor):
  316. # suppress <|notimestamps|> which is handled by without_timestamps
  317. if self.tokenizer.no_timestamps is not None:
  318. logits[:, self.tokenizer.no_timestamps] = -np.inf
  319. # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
  320. for k in range(tokens.shape[0]):
  321. sampled_tokens = tokens[k, self.sample_begin :]
  322. seq = [t for t in sampled_tokens.tolist()]
  323. last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
  324. penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
  325. if last_was_timestamp:
  326. if penultimate_was_timestamp: # has to be non-timestamp
  327. logits[k, self.tokenizer.timestamp_begin :] = -np.inf
  328. else: # cannot be normal text tokens
  329. logits[k, : self.tokenizer.eot] = -np.inf
  330. timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
  331. if timestamps.numel() > 0:
  332. # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
  333. logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
  334. if tokens.shape[1] == self.sample_begin:
  335. # suppress generating non-timestamp tokens at the beginning
  336. logits[:, : self.tokenizer.timestamp_begin] = -np.inf
  337. # apply the `max_initial_timestamp` option
  338. if self.max_initial_timestamp_index is not None:
  339. last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
  340. logits[:, last_allowed + 1 :] = -np.inf
  341. # if sum of probability over timestamps is above any other token, sample timestamp
  342. logprobs = F.log_softmax(logits.float(), dim=-1)
  343. for k in range(tokens.shape[0]):
  344. timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
  345. max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
  346. if timestamp_logprob > max_text_token_logprob:
  347. logits[k, : self.tokenizer.timestamp_begin] = -np.inf
  348. class DecodingTask:
  349. inference: Inference
  350. sequence_ranker: SequenceRanker
  351. decoder: TokenDecoder
  352. logit_filters: List[LogitFilter]
  353. def __init__(self, model: "Whisper", options: DecodingOptions):
  354. self.model = model
  355. language = options.language or "en"
  356. tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
  357. self.tokenizer: Tokenizer = tokenizer
  358. self.options: DecodingOptions = self._verify_options(options)
  359. self.n_group: int = options.beam_size or options.best_of or 1
  360. self.n_ctx: int = model.dims.n_text_ctx
  361. self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
  362. self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
  363. if self.options.without_timestamps:
  364. self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
  365. self.initial_tokens: Tuple[int] = self._get_initial_tokens()
  366. self.sample_begin: int = len(self.initial_tokens)
  367. self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
  368. # inference: implements the forward pass through the decoder, including kv caching
  369. self.inference = PyTorchInference(model, len(self.initial_tokens))
  370. # sequence ranker: implements how to rank a group of sampled sequences
  371. self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
  372. # decoder: implements how to select the next tokens, given the autoregressive distribution
  373. if options.beam_size is not None:
  374. self.decoder = BeamSearchDecoder(
  375. options.beam_size, tokenizer.eot, self.inference, options.patience
  376. )
  377. else:
  378. self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
  379. # logit filters: applies various rules to suppress or penalize certain tokens
  380. self.logit_filters = []
  381. if self.options.suppress_blank:
  382. self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
  383. if self.options.suppress_tokens:
  384. self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
  385. if not options.without_timestamps:
  386. precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
  387. max_initial_timestamp_index = None
  388. if options.max_initial_timestamp:
  389. max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
  390. self.logit_filters.append(
  391. ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
  392. )
  393. def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
  394. if options.beam_size is not None and options.best_of is not None:
  395. raise ValueError("beam_size and best_of can't be given together")
  396. if options.temperature == 0:
  397. if options.best_of is not None:
  398. raise ValueError("best_of with greedy sampling (T=0) is not compatible")
  399. if options.patience is not None and options.beam_size is None:
  400. raise ValueError("patience requires beam_size to be given")
  401. if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
  402. raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
  403. return options
  404. def _get_initial_tokens(self) -> Tuple[int]:
  405. tokens = list(self.sot_sequence)
  406. if prefix := self.options.prefix:
  407. prefix_tokens = (
  408. self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
  409. )
  410. if self.sample_len is not None:
  411. max_prefix_len = self.n_ctx // 2 - self.sample_len
  412. prefix_tokens = prefix_tokens[-max_prefix_len:]
  413. tokens = tokens + prefix_tokens
  414. if prompt := self.options.prompt:
  415. prompt_tokens = (
  416. self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
  417. )
  418. tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
  419. return tuple(tokens)
  420. def _get_suppress_tokens(self) -> Tuple[int]:
  421. suppress_tokens = self.options.suppress_tokens
  422. if isinstance(suppress_tokens, str):
  423. suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
  424. if -1 in suppress_tokens:
  425. suppress_tokens = [t for t in suppress_tokens if t >= 0]
  426. suppress_tokens.extend(self.tokenizer.non_speech_tokens)
  427. elif suppress_tokens is None or len(suppress_tokens) == 0:
  428. suppress_tokens = [] # interpret empty string as an empty list
  429. else:
  430. assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
  431. suppress_tokens.extend(
  432. [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
  433. )
  434. if self.tokenizer.no_speech is not None:
  435. # no-speech probability is collected separately
  436. suppress_tokens.append(self.tokenizer.no_speech)
  437. return tuple(sorted(set(suppress_tokens)))
  438. def _get_audio_features(self, mel: Tensor):
  439. if self.options.fp16:
  440. mel = mel.half()
  441. if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
  442. # encoded audio features are given; skip audio encoding
  443. audio_features = mel
  444. else:
  445. audio_features = self.model.encoder(mel)
  446. if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
  447. return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
  448. return audio_features
  449. def _detect_language(self, audio_features: Tensor, tokens: Tensor):
  450. languages = [self.options.language] * audio_features.shape[0]
  451. lang_probs = None
  452. if self.options.language is None or self.options.task == "lang_id":
  453. lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
  454. languages = [max(probs, key=probs.get) for probs in lang_probs]
  455. if self.options.language is None:
  456. tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
  457. return languages, lang_probs
  458. def _main_loop(self, audio_features: Tensor, tokens: Tensor):
  459. assert audio_features.shape[0] == tokens.shape[0]
  460. n_batch = tokens.shape[0]
  461. sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
  462. no_speech_probs = [np.nan] * n_batch
  463. try:
  464. for i in range(self.sample_len):
  465. logits = self.inference.logits(tokens, audio_features)
  466. if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
  467. probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
  468. no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
  469. # now we need to consider the logits at the last token only
  470. logits = logits[:, -1]
  471. # apply the logit filters, e.g. for suppressing or applying penalty to
  472. for logit_filter in self.logit_filters:
  473. logit_filter.apply(logits, tokens)
  474. # expand the tokens tensor with the selected next tokens
  475. tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
  476. if completed or tokens.shape[-1] > self.n_ctx:
  477. break
  478. finally:
  479. self.inference.cleanup_caching()
  480. return tokens, sum_logprobs, no_speech_probs
  481. @torch.no_grad()
  482. def run(self, mel: Tensor) -> List[DecodingResult]:
  483. self.decoder.reset()
  484. tokenizer: Tokenizer = self.tokenizer
  485. n_audio: int = mel.shape[0]
  486. audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
  487. tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
  488. # detect language if requested, overwriting the language token
  489. languages, language_probs = self._detect_language(audio_features, tokens)
  490. if self.options.task == "lang_id":
  491. return [
  492. DecodingResult(audio_features=features, language=language, language_probs=probs)
  493. for features, language, probs in zip(audio_features, languages, language_probs)
  494. ]
  495. # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
  496. audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
  497. tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
  498. # call the main sampling loop
  499. tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
  500. # reshape the tensors to have (n_audio, n_group) as the first two dimensions
  501. audio_features = audio_features[:: self.n_group]
  502. no_speech_probs = no_speech_probs[:: self.n_group]
  503. assert audio_features.shape[0] == len(no_speech_probs) == n_audio
  504. tokens = tokens.reshape(n_audio, self.n_group, -1)
  505. sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
  506. # get the final candidates for each group, and slice between the first sampled token and EOT
  507. tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
  508. tokens: List[List[Tensor]] = [
  509. [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
  510. ]
  511. # select the top-ranked sample in each group
  512. selected = self.sequence_ranker.rank(tokens, sum_logprobs)
  513. tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
  514. texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
  515. sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
  516. avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
  517. fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
  518. if len(set(map(len, fields))) != 1:
  519. raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
  520. return [
  521. DecodingResult(
  522. audio_features=features,
  523. language=language,
  524. tokens=tokens,
  525. text=text,
  526. avg_logprob=avg_logprob,
  527. no_speech_prob=no_speech_prob,
  528. temperature=self.options.temperature,
  529. compression_ratio=compression_ratio(text),
  530. )
  531. for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
  532. ]
  533. @torch.no_grad()
  534. def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
  535. """
  536. Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
  537. Parameters
  538. ----------
  539. model: Whisper
  540. the Whisper model instance
  541. mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
  542. A tensor containing the Mel spectrogram(s)
  543. options: DecodingOptions
  544. A dataclass that contains all necessary options for decoding 30-second segments
  545. Returns
  546. -------
  547. result: Union[DecodingResult, List[DecodingResult]]
  548. The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
  549. """
  550. if single := mel.ndim == 2:
  551. mel = mel.unsqueeze(0)
  552. result = DecodingTask(model, options).run(mel)
  553. return result[0] if single else result