# decoding.py

from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio

if TYPE_CHECKING:
    from .model import Whisper


@torch.no_grad()
def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return the ids of the most probable
    language tokens along with the probability distribution over all language tokens.
    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Returns
    -------
    language_tokens : Tensor, shape = (n_audio,)
        ids of the most probable language tokens, which appear after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(
            model.is_multilingual, num_languages=model.num_languages
        )
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs
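
# Usage sketch (illustrative addition, not part of the original module): this function
# is usually reached through `model.detect_language(mel)`, where `mel` is a log-Mel
# spectrogram of shape (80, 3000) or already-encoded audio features, e.g.
#
#     _, probs = detect_language(model, mel)
#     print(max(probs, key=probs.get))  # most likely language code, e.g. "en"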


@dataclass(frozen=True)
class DecodingOptions:
    # whether to perform X->X "transcribe" or X->English "translate"
    task: str = "transcribe"

    # language that the audio is in; uses detected language if None
    language: Optional[str] = None

    # sampling-related options
    temperature: float = 0.0
    sample_len: Optional[int] = None  # maximum number of tokens to sample
    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)

    # "alpha" in Google NMT, or None for length norm, when ranking generations
    # to select which to return among the beams or best-of-N samples
    length_penalty: Optional[float] = None

    # text or tokens to feed as the prompt or the prefix; for more info:
    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context

    # list of token ids (or comma-separated token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True  # this will suppress blank outputs

    # timestamp sampling options
    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
    max_initial_timestamp: Optional[float] = 1.0

    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation
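
# Construction sketch (illustrative addition): the dataclass is frozen, so variants
# are derived with `dataclasses.replace` (imported above), e.g.
#
#     greedy = DecodingOptions(language="en", temperature=0.0)
#     beam = replace(greedy, beam_size=5, patience=1.0)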


@dataclass(frozen=True)
class DecodingResult:
    audio_features: Tensor
    language: str
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan


class Inference:
    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Perform a forward pass on the decoder and return per-token logits"""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Update the key-value cache according to the updated beams"""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Clean up any resources or hooks after decoding is finished"""
        pass


class PyTorchInference(Inference):
    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

        key_modules = [block.attn.key for block in self.model.decoder.blocks]
        value_modules = [block.attn.value for block in self.model.decoder.blocks]
        self.kv_modules = key_modules + value_modules

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        if tokens.shape[-1] > self.initial_token_length:
            # only need to use the last token except in the first forward pass
            tokens = tokens[:, -1:]

        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

    def cleanup_caching(self):
        for hook in self.hooks:
            hook.remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        if source_indices != list(range(len(source_indices))):
            for module in self.kv_modules:
                # update the key/value cache to contain the selected sequences
                self.kv_cache[module] = self.kv_cache[module][source_indices].detach()


class SequenceRanker:
    def rank(
        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        Given a list of groups of samples and their cumulative log probabilities,
        return the indices of the samples in each group to select as the final result
        """
        raise NotImplementedError


class MaximumLikelihoodRanker(SequenceRanker):
    """
    Select the sample with the highest log probabilities, penalized using either
    a simple length normalization or Google NMT paper's length penalty
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
        def scores(logprobs, lengths):
            result = []
            for logprob, length in zip(logprobs, lengths):
                if self.length_penalty is None:
                    penalty = length
                else:
                    # from the Google NMT paper
                    penalty = ((5 + length) / 6) ** self.length_penalty
                result.append(logprob / penalty)
            return result

        # get the sequence with the highest score
        lengths = [[len(t) for t in s] for s in tokens]
        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
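
# Worked example (illustrative addition): with length_penalty=0.6 and a 20-token
# candidate, the divisor is ((5 + 20) / 6) ** 0.6 ≈ 2.35, so the cumulative
# log-probability is divided by ≈ 2.35 rather than by the raw length 20 used by
# plain length normalization (length_penalty=None).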


class TokenDecoder:
    def reset(self):
        """Initialize any stateful variables for decoding a new sequence"""

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """Specify how to select the next token, based on the current trace and logits

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step
        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token
        completed : bool
            True if all sequences have reached the end of text
        """
        raise NotImplementedError

    def finalize(
        self, tokens: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
        """Finalize search and return the final candidate sequences

        Parameters
        ----------
        tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence
        sum_logprobs : Tensor, shape = (n_audio, n_group)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = n_audio
            sequence of Tensors containing candidate token sequences, for each audio input
        sum_logprobs : List[List[float]], length = n_audio
            sequence of cumulative log probabilities corresponding to the above
        """
        raise NotImplementedError


class GreedyDecoder(TokenDecoder):
    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / self.temperature).sample()

        logprobs = F.log_softmax(logits.float(), dim=-1)
        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
        # make sure each sequence has at least one EOT token at the end
        tokens = F.pad(tokens, (0, 1), value=self.eot)
        return tokens, sum_logprobs.tolist()
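
# Behaviour sketch (illustrative addition): with temperature 0 the decoder is purely
# greedy, e.g. given logits [[0.1, 2.0, -1.0]] it appends token 1; with a positive
# temperature it samples from Categorical(logits / temperature) instead. A sequence
# whose last token is already EOT keeps emitting EOT and contributes nothing further
# to sum_logprobs.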


class BeamSearchDecoder(TokenDecoder):
    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        patience: Optional[float] = None,
    ):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        self.patience = patience or 1.0
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert (
            self.max_candidates > 0
        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        self.finished_sequences = None

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(
            self.finished_sequences, finished_sequences
        ):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio has enough samples
        completed = all(
            len(sequences) >= self.max_candidates
            for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if (
                len(sequences) < self.beam_size
            ):  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()]
            for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs
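
# Note (illustrative addition): `patience` scales how many finished candidates are
# collected before the search stops; e.g. beam_size=5 with patience=2.0 keeps up to
# round(5 * 2.0) = 10 finished sequences per audio, while the default patience of 1.0
# stops as soon as beam_size sequences have ended in EOT.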


class LogitFilter:
    def apply(self, logits: Tensor, tokens: Tensor) -> None:
        """Apply any filtering or masking to logits in-place

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens
        """
        raise NotImplementedError


class SuppressBlank(LogitFilter):
    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: Tensor, tokens: Tensor):
        if tokens.shape[1] == self.sample_begin:
            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf


class SuppressTokens(LogitFilter):
    def __init__(self, suppress_tokens: Sequence[int]):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: Tensor, tokens: Tensor):
        logits[:, self.suppress_tokens] = -np.inf


class ApplyTimestampRules(LogitFilter):
    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: Tensor, tokens: Tensor):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = [t for t in sampled_tokens.tolist()]
            last_was_timestamp = (
                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            )
            penultimate_was_timestamp = (
                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
            )

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf

            timestamps = sampled_tokens[
                sampled_tokens.ge(self.tokenizer.timestamp_begin)
            ]
            if timestamps.numel() > 0:
                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                # also force each segment to have a nonzero length, to prevent infinite looping
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    timestamp_last = timestamps[-1] + 1

                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            # suppress generating non-timestamp tokens at the beginning
            logits[:, : self.tokenizer.timestamp_begin] = -np.inf

            # apply the `max_initial_timestamp` option
            if self.max_initial_timestamp_index is not None:
                last_allowed = (
                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
                )
                logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
                dim=-1
            )
            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf


class DecodingTask:
    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(self, model: "Whisper", options: DecodingOptions):
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(
            model.is_multilingual,
            num_languages=model.num_languages,
            language=language,
            task=options.task,
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)

        self.n_group: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = PyTorchInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: applies various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(
                    self.options.max_initial_timestamp / precision
                )
            self.logit_filters.append(
                ApplyTimestampRules(
                    tokenizer, self.sample_begin, max_initial_timestamp_index
                )
            )

    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (
            0 <= options.length_penalty <= 1
        ):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        tokens = list(self.sot_sequence)

        if prefix := self.options.prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip())
                if isinstance(prefix, str)
                else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt := self.options.prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip())
                if isinstance(prompt, str)
                else prompt
            )
            tokens = (
                [self.tokenizer.sot_prev]
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )

        return tuple(tokens)

    def _get_suppress_tokens(self) -> Tuple[int]:
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        else:
            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

        suppress_tokens.extend(
            [
                self.tokenizer.transcribe,
                self.tokenizer.translate,
                self.tokenizer.sot,
                self.tokenizer.sot_prev,
                self.tokenizer.sot_lm,
            ]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel: Tensor):
        if self.options.fp16:
            mel = mel.half()

        if mel.shape[-2:] == (
            self.model.dims.n_audio_ctx,
            self.model.dims.n_audio_state,
        ):
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            audio_features = self.model.encoder(mel)

        if audio_features.dtype != (
            torch.float16 if self.options.fp16 else torch.float32
        ):
            raise TypeError(
                f"audio_features has an incorrect dtype: {audio_features.dtype}"
            )

        return audio_features

    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs

    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if (
                    i == 0 and self.tokenizer.no_speech is not None
                ):  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or penalizing certain tokens
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    @torch.no_grad()
    def run(self, mel: Tensor) -> List[DecodingResult]:
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(
                    audio_features=features, language=language, language_probs=probs
                )
                for features, language, probs in zip(
                    audio_features, languages, language_probs
                )
            ]

        # repeat text tensors by the group size, for beam search or best-of-n sampling
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
            for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [
            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
        ]

        fields = (
            texts,
            languages,
            tokens,
            audio_features,
            avg_logprobs,
            no_speech_probs,
        )
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
                *fields
            )
        ]


@torch.no_grad()
def decode(
    model: "Whisper",
    mel: Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    if single := mel.ndim == 2:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = DecodingTask(model, options).run(mel)

    return result[0] if single else result
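

if __name__ == "__main__":  # pragma: no cover
    # Minimal usage sketch (illustrative addition, not part of the original module);
    # run with `python -m whisper.decoding`. It assumes the package-level helpers
    # `load_model`, `load_audio`, `pad_or_trim`, and `log_mel_spectrogram`, plus a
    # local "audio.mp3" file.
    import whisper

    model = whisper.load_model("base")
    audio = whisper.pad_or_trim(whisper.load_audio("audio.mp3"))
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # fp16 is only useful on GPU; fall back to fp32 on CPU
    options = DecodingOptions(language="en", fp16=torch.cuda.is_available())
    result = decode(model, mel, options)
    print(result.text)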