@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -16,7 +16,9 @@ if TYPE_CHECKING:


 @torch.no_grad()
-def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
+def detect_language(
+    model: "Whisper", mel: Tensor, tokenizer: Optional[Tokenizer] = None
+) -> Tuple[Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
     of the most probable language tokens and the probability distribution over all language tokens.
@@ -31,8 +33,13 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
     """
     if tokenizer is None:
         tokenizer = get_tokenizer(model.is_multilingual)
-    if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
-        raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
+    if (
+        tokenizer.language is None
+        or tokenizer.language_token not in tokenizer.sot_sequence
+    ):
+        raise ValueError(
+            "This model doesn't have language tokens so it can't perform lang id"
+        )

     single = mel.ndim == 2
     if single:
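
For orientation: the `detect_language` being rewrapped above is normally reached via `model.detect_language`. A minimal usage sketch, assuming a multilingual checkpoint and a placeholder audio path:

```python
import whisper

model = whisper.load_model("base")  # English-only models fail the language-token check above
audio = whisper.pad_or_trim(whisper.load_audio("speech.wav"))  # placeholder path
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# for a single (2-D) mel, returns one language token and one {language: prob} dict
lang_token, lang_probs = model.detect_language(mel)
print(max(lang_probs, key=lang_probs.get))
```
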
@@ -70,31 +77,36 @@


 @dataclass(frozen=True)
 class DecodingOptions:
     task: str = "transcribe"
     language: Optional[str] = None

     temperature: float = 0.0
     sample_len: Optional[int] = None
     best_of: Optional[int] = None
     beam_size: Optional[int] = None
     patience: Optional[float] = None

     length_penalty: Optional[float] = None

     prompt: Optional[Union[str, List[int]]] = None
     prefix: Optional[Union[str, List[int]]] = None
-    suppress_blank: bool = True

     suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
+    suppress_blank: bool = True

     without_timestamps: bool = False
     max_initial_timestamp: Optional[float] = 1.0

     fp16: bool = True
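
`DecodingOptions` is a frozen dataclass, so a configuration is immutable once constructed and a fresh instance is built per decoding run. A sketch of typical construction (field values are illustrative):

```python
from whisper.decoding import DecodingOptions

options = DecodingOptions(
    task="transcribe",
    language="en",
    temperature=0.0,          # greedy / beam decoding
    beam_size=5,              # beam search is only valid when temperature == 0
    without_timestamps=True,  # sample text tokens only
)
# options.temperature = 0.5 would raise dataclasses.FrozenInstanceError
```
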
@@ -158,7 +170,9 @@ class PyTorchInference(Inference):


 class SequenceRanker:
-    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
+    def rank(
+        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
+    ) -> List[int]:
         """
         Given a list of groups of samples and their cumulative log probabilities,
         return the indices of the samples in each group to select as the final result
@@ -196,7 +210,9 @@ class TokenDecoder:
     def reset(self):
         """Initialize any stateful variables for decoding a new sequence"""

-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         """Specify how to select the next token, based on the current trace and logits

         Parameters
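
`update` is the extension point implemented by the greedy and beam-search decoders below. A hypothetical, do-nothing subclass just to show the contract (not part of this diff):

```python
from typing import Tuple

import torch
from torch import Tensor


class EotOnlyDecoder:  # hypothetical TokenDecoder: always emits end-of-text
    def __init__(self, eot: int):
        self.eot = eot

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        # append EOT to every sequence and report the batch as completed
        next_tokens = torch.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
        return torch.cat([tokens, next_tokens], dim=-1), True
```
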
@@ -251,7 +267,9 @@ class GreedyDecoder(TokenDecoder):
         self.temperature = temperature
         self.eot = eot

-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
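
The two branches being rewrapped implement argmax selection at temperature 0 and categorical sampling otherwise, the same rule as this standalone sketch on plain tensors:

```python
import torch

def pick_next(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # logits: (batch, vocab); mirrors the branch in GreedyDecoder.update
    if temperature == 0:
        return logits.argmax(dim=-1)  # deterministic
    return torch.distributions.Categorical(logits=logits / temperature).sample()

logits = torch.randn(2, 8)
print(pick_next(logits, 0.0), pick_next(logits, 0.7))
```
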
@@ -274,7 +292,13 @@ class GreedyDecoder(TokenDecoder):


 class BeamSearchDecoder(TokenDecoder):
-    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
+    def __init__(
+        self,
+        beam_size: int,
+        eot: int,
+        inference: Inference,
+        patience: Optional[float] = None,
+    ):
         self.beam_size = beam_size
         self.eot = eot
         self.inference = inference
@@ -282,12 +306,16 @@ class BeamSearchDecoder(TokenDecoder):
         self.max_candidates: int = round(beam_size * self.patience)
         self.finished_sequences = None

-        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
+        assert (
+            self.max_candidates > 0
+        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

     def reset(self):
         self.finished_sequences = None

-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         if tokens.shape[0] % self.beam_size != 0:
             raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
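
The assert being rewrapped guards `max_candidates = round(beam_size * patience)` from the constructor above; `patience` defaults to 1.0, so the search normally keeps exactly `beam_size` finished candidates per audio. Checking the arithmetic:

```python
beam_size, patience = 5, 2.0   # patience > 1.0 searches longer (arXiv:2204.05424)
max_candidates = round(beam_size * (patience or 1.0))
assert max_candidates == 10
# a patience near zero would round max_candidates down to 0 and trip the assert
```
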
@@ -331,7 +359,9 @@ class BeamSearchDecoder(TokenDecoder):

         assert len(self.finished_sequences) == len(finished_sequences)
-        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
+        for previously_finished, newly_finished in zip(
+            self.finished_sequences, finished_sequences
+        ):
             for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                 if len(previously_finished) >= self.max_candidates:
                     break
@@ -339,7 +369,8 @@ class BeamSearchDecoder(TokenDecoder):

         completed = all(
-            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
+            len(sequences) >= self.max_candidates
+            for sequences in self.finished_sequences
         )
         return tokens, completed
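
`completed` only turns true once every audio in the batch has accumulated `max_candidates` finished beams; with one dict of finished sequences per audio, the check behaves like this sketch:

```python
max_candidates = 5
# one {token_tuple: sum_logprob} dict per audio; values here are made up
finished_sequences = [{}, {(1, 2, 50257): -3.2}]
completed = all(len(seqs) >= max_candidates for seqs in finished_sequences)
print(completed)  # False: neither audio has 5 finished beams yet
```
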
@@ -347,7 +378,9 @@ class BeamSearchDecoder(TokenDecoder):

         sum_logprobs = sum_logprobs.cpu()
         for i, sequences in enumerate(self.finished_sequences):
-            if len(sequences) < self.beam_size:
+            if (
+                len(sequences) < self.beam_size
+            ):  # when not enough sequences are finished yet
                 for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                     sequence = preceding_tokens[i, j].tolist() + [self.eot]
                     sequences[tuple(sequence)] = sum_logprobs[i][j].item()
@@ -355,7 +388,8 @@ class BeamSearchDecoder(TokenDecoder):
                         break

         tokens: List[List[Tensor]] = [
-            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
+            [torch.tensor(seq) for seq in sequences.keys()]
+            for sequences in self.finished_sequences
         ]
         sum_logprobs: List[List[float]] = [
             list(sequences.values()) for sequences in self.finished_sequences
@@ -399,7 +433,10 @@ class SuppressTokens(LogitFilter):

 class ApplyTimestampRules(LogitFilter):
     def __init__(
-        self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
+        self,
+        tokenizer: Tokenizer,
+        sample_begin: int,
+        max_initial_timestamp_index: Optional[int],
     ):
         self.tokenizer = tokenizer
         self.sample_begin = sample_begin
@@ -414,8 +451,12 @@ class ApplyTimestampRules(LogitFilter):
         for k in range(tokens.shape[0]):
             sampled_tokens = tokens[k, self.sample_begin :]
             seq = [t for t in sampled_tokens.tolist()]
-            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
-            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+            last_was_timestamp = (
+                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
+            )
+            penultimate_was_timestamp = (
+                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+            )

             if last_was_timestamp:
                 if penultimate_was_timestamp:
@@ -423,7 +464,9 @@ class ApplyTimestampRules(LogitFilter):
                 else:
                     logits[k, : self.tokenizer.eot] = -np.inf

-            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
+            timestamps = sampled_tokens[
+                sampled_tokens.ge(self.tokenizer.timestamp_begin)
+            ]
             if timestamps.numel() > 0:
                 # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                 logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
@@ -434,13 +477,17 @@ class ApplyTimestampRules(LogitFilter):

             # apply the `max_initial_timestamp` option
             if self.max_initial_timestamp_index is not None:
-                last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                last_allowed = (
+                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                )
                 logits[:, last_allowed + 1 :] = -np.inf

         # if sum of probability over timestamps is above any other token, sample timestamp
         logprobs = F.log_softmax(logits.float(), dim=-1)
         for k in range(tokens.shape[0]):
-            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
+            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
+                dim=-1
+            )
             max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
             if timestamp_logprob > max_text_token_logprob:
                 logits[k, : self.tokenizer.timestamp_begin] = -np.inf
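
The last loop compares the total probability mass on timestamp tokens (hence `logsumexp`) against the single best text token, and masks text out when timestamps win. A toy version of that comparison, using a hypothetical 6-token vocabulary whose last two ids are timestamps:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 0.5, 0.2, 0.1, 1.2, 1.1]])  # made-up values
timestamp_begin = 4                                       # hypothetical split point

logprobs = F.log_softmax(logits.float(), dim=-1)
timestamp_logprob = logprobs[0, timestamp_begin:].logsumexp(dim=-1)
max_text_token_logprob = logprobs[0, :timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
    logits[0, :timestamp_begin] = float("-inf")  # force a timestamp token
```
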
@@ -456,7 +503,9 @@ class DecodingTask:
         self.model = model

         language = options.language or "en"
-        tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
+        tokenizer = get_tokenizer(
+            model.is_multilingual, language=language, task=options.task
+        )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
@@ -496,9 +545,13 @@ class DecodingTask:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx
             max_initial_timestamp_index = None
             if options.max_initial_timestamp:
-                max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
+                max_initial_timestamp_index = round(
+                    self.options.max_initial_timestamp / precision
+                )
             self.logit_filters.append(
-                ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
+                ApplyTimestampRules(
+                    tokenizer, self.sample_begin, max_initial_timestamp_index
+                )
             )

     def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
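
The wrapped `round(...)` converts `max_initial_timestamp` from seconds into a timestamp-token index. For the released checkpoints, `CHUNK_LENGTH` is 30 and `n_audio_ctx` is 1500, so each timestamp token spans 0.02 s and the default 1.0 s cap becomes index 50:

```python
CHUNK_LENGTH = 30    # whisper.audio.CHUNK_LENGTH, seconds per segment
n_audio_ctx = 1500   # audio positions in the released models
precision = CHUNK_LENGTH / n_audio_ctx   # 0.02 s per timestamp token
print(round(1.0 / precision))            # default 1.0 s cap -> index 50
```
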
@@ -509,7 +562,9 @@ class DecodingTask:
             raise ValueError("best_of with greedy sampling (T=0) is not compatible")
         if options.patience is not None and options.beam_size is None:
             raise ValueError("patience requires beam_size to be given")
-        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
+        if options.length_penalty is not None and not (
+            0 <= options.length_penalty <= 1
+        ):
             raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

         return options
@@ -519,7 +574,9 @@ class DecodingTask:

         if prefix := self.options.prefix:
             prefix_tokens = (
-                self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
+                self.tokenizer.encode(" " + prefix.strip())
+                if isinstance(prefix, str)
+                else prefix
             )
             if self.sample_len is not None:
                 max_prefix_len = self.n_ctx // 2 - self.sample_len
@@ -528,9 +585,15 @@ class DecodingTask:

         if prompt := self.options.prompt:
             prompt_tokens = (
-                self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
+                self.tokenizer.encode(" " + prompt.strip())
+                if isinstance(prompt, str)
+                else prompt
+            )
+            tokens = (
+                [self.tokenizer.sot_prev]
+                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
+                + tokens
             )
-            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens

         return tuple(tokens)
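
As the two blocks above show, `prefix` seeds the start of the current transcript while `prompt` is spliced in behind `sot_prev` as prior context. A usage sketch through the public API; the audio path is a placeholder, and `fp16=False` just keeps the sketch CPU-friendly:

```python
import whisper
from whisper.decoding import DecodingOptions

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("speech.wav"))  # placeholder path
mel = whisper.log_mel_spectrogram(audio).to(model.device)

options = DecodingOptions(prompt="previous window's text.", prefix="The", fp16=False)
print(whisper.decode(model, mel, options).text)
```
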
@@ -554,7 +617,7 @@ class DecodingTask:
                 self.tokenizer.translate,
                 self.tokenizer.sot,
                 self.tokenizer.sot_prev,
-                self.tokenizer.sot_lm
+                self.tokenizer.sot_lm,
             ]
         )
         if self.tokenizer.no_speech is not None:
@@ -567,14 +630,21 @@ class DecodingTask:
         if self.options.fp16:
             mel = mel.half()

-        if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
+        if mel.shape[-2:] == (
+            self.model.dims.n_audio_ctx,
+            self.model.dims.n_audio_state,
+        ):
             # encoded audio features are given; skip audio encoding
             audio_features = mel
         else:
             audio_features = self.model.encoder(mel)

-        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
-            return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
+        if audio_features.dtype != (
+            torch.float16 if self.options.fp16 else torch.float32
+        ):
+            raise TypeError(
+                f"audio_features has an incorrect dtype: {audio_features.dtype}"
+            )

         return audio_features
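
The shape test above lets callers hand `decode` precomputed encoder output instead of a mel, in which case the encoder is skipped. A sketch of the distinction, with dimensions assumed from the `base` model:

```python
import torch

n_audio_ctx, n_audio_state = 1500, 512             # assumed base-model dims
mel = torch.randn(80, 3000)                        # log-Mel input: run the encoder
encoded = torch.randn(n_audio_ctx, n_audio_state)  # already-encoded features

def is_encoded(x: torch.Tensor) -> bool:
    # mirrors the mel.shape[-2:] comparison in _get_audio_features
    return x.shape[-2:] == (n_audio_ctx, n_audio_state)

print(is_encoded(mel), is_encoded(encoded))  # False True
```
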
@@ -583,7 +653,9 @@ class DecodingTask:
         lang_probs = None

         if self.options.language is None or self.options.task == "lang_id":
-            lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
+            lang_tokens, lang_probs = self.model.detect_language(
+                audio_features, self.tokenizer
+            )
             languages = [max(probs, key=probs.get) for probs in lang_probs]
             if self.options.language is None:
                 tokens[:, self.sot_index + 1] = lang_tokens
@@ -600,7 +672,9 @@ class DecodingTask:
         for i in range(self.sample_len):
             logits = self.inference.logits(tokens, audio_features)

-            if i == 0 and self.tokenizer.no_speech is not None:
+            if (
+                i == 0 and self.tokenizer.no_speech is not None
+            ):  # save no_speech_probs
                 probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                 no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
@@ -634,8 +708,12 @@ class DecodingTask:
         languages, language_probs = self._detect_language(audio_features, tokens)
         if self.options.task == "lang_id":
             return [
-                DecodingResult(audio_features=features, language=language, language_probs=probs)
-                for features, language, probs in zip(audio_features, languages, language_probs)
+                DecodingResult(
+                    audio_features=features, language=language, language_probs=probs
+                )
+                for features, language, probs in zip(
+                    audio_features, languages, language_probs
+                )
             ]
@@ -656,7 +734,8 @@ class DecodingTask:

         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
         tokens: List[List[Tensor]] = [
-            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
+            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
+            for s in tokens
         ]
@@ -665,9 +744,18 @@ class DecodingTask:
         texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
-        avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
+        avg_logprobs: List[float] = [
+            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
+        ]

-        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
+        fields = (
+            texts,
+            languages,
+            tokens,
+            audio_features,
+            avg_logprobs,
+            no_speech_probs,
+        )
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
@@ -682,12 +770,16 @@ class DecodingTask:
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
             )
-            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
+                *fields
+            )
         ]


 @torch.no_grad()
-def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
+def decode(
+    model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
+) -> Union[DecodingResult, List[DecodingResult]]:
     """
     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
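
As the new signature spells out, `decode` returns a single `DecodingResult` for a 2-D mel and a list of results for a batched 3-D mel. Continuing the earlier sketches:

```python
import whisper
from whisper.decoding import DecodingOptions

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("speech.wav"))  # placeholder path
mel = whisper.log_mel_spectrogram(audio).to(model.device)
options = DecodingOptions(fp16=False)  # fp16 defaults to True; disable it on CPU

single = whisper.decode(model, mel, options)              # (80, 3000) -> DecodingResult
batch = whisper.decode(model, mel.unsqueeze(0), options)  # (1, 80, 3000) -> List[DecodingResult]
assert isinstance(batch, list)
```
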