# decoding.py
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio

if TYPE_CHECKING:
    from .model import Whisper

@torch.no_grad()
def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return the ids of the most probable
    language tokens along with the probability distribution over all language tokens.

    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Returns
    -------
    language_tokens : Tensor, shape = (n_audio,)
        ids of the most probable language tokens, which appear after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(model.is_multilingual)
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs
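
# Usage sketch (added for illustration, not part of the original module): `detect_language`
# is typically reached through the bound `Whisper.detect_language` method. The sketch below
# assumes the top-level package helpers `load_model`, `load_audio`, `pad_or_trim`, and
# `log_mel_spectrogram`, and a hypothetical file "audio.mp3"; adjust to your setup.
#
#     import whisper
#
#     model = whisper.load_model("base")
#     audio = whisper.pad_or_trim(whisper.load_audio("audio.mp3"))
#     mel = whisper.log_mel_spectrogram(audio).to(model.device)
#     _, probs = model.detect_language(mel)  # bound to the function above
#     print(f"Detected language: {max(probs, key=probs.get)}")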

@dataclass(frozen=True)
class DecodingOptions:
    # whether to perform X->X "transcribe" or X->English "translate"
    task: str = "transcribe"

    # language that the audio is in; uses detected language if None
    language: Optional[str] = None

    # sampling-related options
    temperature: float = 0.0
    sample_len: Optional[int] = None  # maximum number of tokens to sample
    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)

    # "alpha" in Google NMT, or None for length norm, when ranking generations
    # to select which to return among the beams or best-of-N samples
    length_penalty: Optional[float] = None

    # text or tokens to feed as the prompt or the prefix; for more info:
    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context

    # list of token ids (or comma-separated token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True  # this will suppress blank outputs

    # timestamp sampling options
    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
    max_initial_timestamp: Optional[float] = 1.0

    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation
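
# Option sketches (added for illustration, not part of the original module): two typical,
# mutually exclusive configurations. Greedy sampling uses `temperature`/`best_of`, while
# beam search uses `beam_size`/`patience`; the concrete values below are illustrative only.
#
#     greedy_options = DecodingOptions(language="en", temperature=0.0, fp16=False)
#     beam_options = DecodingOptions(task="translate", beam_size=5, patience=1.0)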

@dataclass(frozen=True)
class DecodingResult:
    audio_features: Tensor
    language: str
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan

class Inference:
    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Perform a forward pass on the decoder and return per-token logits"""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Update the key-value cache according to the updated beams"""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Clean up any resources or hooks after decoding is finished"""
        pass

class PyTorchInference(Inference):
    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        if tokens.shape[-1] > self.initial_token_length:
            # only need to use the last token except in the first forward pass
            tokens = tokens[:, -1:]

        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

    def cleanup_caching(self):
        for hook in self.hooks:
            hook.remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        for module, tensor in self.kv_cache.items():
            # update the key/value cache to contain the selected sequences
            self.kv_cache[module] = tensor[source_indices].detach()

class SequenceRanker:
    def rank(
        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        Given a list of groups of samples and their cumulative log probabilities,
        return the indices of the samples in each group to select as the final result
        """
        raise NotImplementedError

class MaximumLikelihoodRanker(SequenceRanker):
    """
    Select the sample with the highest log probabilities, penalized using either
    a simple length normalization or Google NMT paper's length penalty
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
        def scores(logprobs, lengths):
            result = []
            for logprob, length in zip(logprobs, lengths):
                if self.length_penalty is None:
                    penalty = length
                else:
                    # from the Google NMT paper
                    penalty = ((5 + length) / 6) ** self.length_penalty
                result.append(logprob / penalty)
            return result

        # get the sequence with the highest score
        lengths = [[len(t) for t in s] for s in tokens]
        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
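
# Worked example (added for clarity, not in the original file): with length_penalty=0.6
# (the "alpha" from the Google NMT paper), a 10-token candidate is scored as
#     logprob / ((5 + 10) / 6) ** 0.6  ~=  logprob / 1.73
# whereas length_penalty=None divides by the raw length (logprob / 10), which penalizes
# long sequences more aggressively.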

class TokenDecoder:
    def reset(self):
        """Initialize any stateful variables for decoding a new sequence"""

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """Specify how to select the next token, based on the current trace and logits

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step
        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token
        completed : bool
            True if all sequences have reached the end of text
        """
        raise NotImplementedError

    def finalize(
        self, tokens: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
        """Finalize search and return the final candidate sequences

        Parameters
        ----------
        tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence
        sum_logprobs : Tensor, shape = (n_audio, n_group)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = n_audio
            sequence of Tensors containing candidate token sequences, for each audio input
        sum_logprobs : List[List[float]], length = n_audio
            sequence of cumulative log probabilities corresponding to the above
        """
        raise NotImplementedError

class GreedyDecoder(TokenDecoder):
    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / self.temperature).sample()

        logprobs = F.log_softmax(logits.float(), dim=-1)
        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
        # make sure each sequence has at least one EOT token at the end
        tokens = F.pad(tokens, (0, 1), value=self.eot)
        return tokens, sum_logprobs.tolist()
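
# Note (added for clarity, not in the original file): dividing the logits by the temperature
# controls how peaked the sampling distribution is. A small sketch with hypothetical logits
# showing how the two branches in `update` differ:
#
#     logits = torch.tensor([[2.0, 1.0, 0.5]])
#     logits.argmax(dim=-1)                      # greedy (t == 0): always picks index 0
#     Categorical(logits=logits / 0.7).sample()  # t > 0: usually 0, sometimes 1 or 2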

class BeamSearchDecoder(TokenDecoder):
    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        patience: Optional[float] = None,
    ):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        self.patience = patience or 1.0
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert (
            self.max_candidates > 0
        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        self.finished_sequences = None

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(
            self.finished_sequences, finished_sequences
        ):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio has enough samples
        completed = all(
            len(sequences) >= self.max_candidates
            for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if (
                len(sequences) < self.beam_size
            ):  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()]
            for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs
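
# Worked example (added for clarity, not in the original file): with beam_size=5 and
# patience=2.0, max_candidates = round(5 * 2.0) = 10, so the search keeps expanding beams
# until 10 finished (EOT-terminated) candidates per audio have been collected rather than
# stopping at the first 5; patience=None behaves like patience=1.0.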

class LogitFilter:
    def apply(self, logits: Tensor, tokens: Tensor) -> None:
        """Apply any filtering or masking to logits in-place

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens
        """
        raise NotImplementedError

class SuppressBlank(LogitFilter):
    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: Tensor, tokens: Tensor):
        if tokens.shape[1] == self.sample_begin:
            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf

class SuppressTokens(LogitFilter):
    def __init__(self, suppress_tokens: Sequence[int]):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: Tensor, tokens: Tensor):
        logits[:, self.suppress_tokens] = -np.inf

class ApplyTimestampRules(LogitFilter):
    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: Tensor, tokens: Tensor):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = [t for t in sampled_tokens.tolist()]
            last_was_timestamp = (
                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            )
            penultimate_was_timestamp = (
                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
            )

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf

            timestamps = sampled_tokens[
                sampled_tokens.ge(self.tokenizer.timestamp_begin)
            ]
            if timestamps.numel() > 0:
                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            # suppress generating non-timestamp tokens at the beginning
            logits[:, : self.tokenizer.timestamp_begin] = -np.inf

            # apply the `max_initial_timestamp` option
            if self.max_initial_timestamp_index is not None:
                last_allowed = (
                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
                )
                logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
                dim=-1
            )
            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf
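
# Illustration of the pairing rule above (added for clarity, not in the original file):
# a timestamped segment is decoded roughly as
#     <|0.00|> (text tokens) <|7.24|><|7.24|> (text tokens) <|10.00|><|endoftext|>
# i.e. each text span is bracketed by a start and an end timestamp, the end timestamp of
# one span is immediately followed by the start timestamp of the next, timestamps never
# decrease, and only the final timestamp may be followed directly by EOT.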

class DecodingTask:
    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(self, model: "Whisper", options: DecodingOptions):
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(
            model.is_multilingual, language=language, task=options.task
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)

        self.n_group: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = PyTorchInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: applies various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(
                    self.options.max_initial_timestamp / precision
                )
            self.logit_filters.append(
                ApplyTimestampRules(
                    tokenizer, self.sample_begin, max_initial_timestamp_index
                )
            )

    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (
            0 <= options.length_penalty <= 1
        ):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        tokens = list(self.sot_sequence)

        if prefix := self.options.prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip())
                if isinstance(prefix, str)
                else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt := self.options.prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip())
                if isinstance(prompt, str)
                else prompt
            )
            tokens = (
                [self.tokenizer.sot_prev]
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )

        return tuple(tokens)

    def _get_suppress_tokens(self) -> Tuple[int]:
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        else:
            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

        suppress_tokens.extend(
            [
                self.tokenizer.transcribe,
                self.tokenizer.translate,
                self.tokenizer.sot,
                self.tokenizer.sot_prev,
                self.tokenizer.sot_lm,
            ]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel: Tensor):
        if self.options.fp16:
            mel = mel.half()

        if mel.shape[-2:] == (
            self.model.dims.n_audio_ctx,
            self.model.dims.n_audio_state,
        ):
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            audio_features = self.model.encoder(mel)

        if audio_features.dtype != (
            torch.float16 if self.options.fp16 else torch.float32
        ):
            raise TypeError(
                f"audio_features has an incorrect dtype: {audio_features.dtype}"
            )

        return audio_features

    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs

    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        assert audio_features.shape[0] == tokens.shape[0]
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if (
                    i == 0 and self.tokenizer.no_speech is not None
                ):  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or penalizing certain tokens
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    @torch.no_grad()
    def run(self, mel: Tensor) -> List[DecodingResult]:
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(
                    audio_features=features, language=language, language_probs=probs
                )
                for features, language, probs in zip(
                    audio_features, languages, language_probs
                )
            ]

        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
        audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
            for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [
            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
        ]

        fields = (
            texts,
            languages,
            tokens,
            audio_features,
            avg_logprobs,
            no_speech_probs,
        )
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
                *fields
            )
        ]

@torch.no_grad()
def decode(
    model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    if single := mel.ndim == 2:
        mel = mel.unsqueeze(0)

    result = DecodingTask(model, options).run(mel)

    return result[0] if single else result
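
# End-to-end sketch (added for illustration, not part of the original module): calling
# `decode` directly on a single 30-second mel segment. It assumes the top-level package
# helpers `load_model`, `load_audio`, `pad_or_trim`, and `log_mel_spectrogram`, plus a
# hypothetical file "audio.mp3"; a 2-D mel yields a single DecodingResult.
#
#     import whisper
#
#     model = whisper.load_model("base")
#     audio = whisper.pad_or_trim(whisper.load_audio("audio.mp3"))
#     mel = whisper.log_mel_spectrogram(audio).to(model.device)
#     result = whisper.decode(model, mel, whisper.DecodingOptions(fp16=False))
#     print(result.text, result.avg_logprob, result.no_speech_prob)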