@@ -108,7 +108,7 @@ class DecodingResult:
     tokens: List[int] = field(default_factory=list)
     text: str = ""
     avg_logprob: float = np.nan
-    no_caption_prob: float = np.nan
+    no_speech_prob: float = np.nan
     temperature: float = np.nan
     compression_ratio: float = np.nan

@@ -543,9 +543,9 @@ class DecodingTask:
         suppress_tokens.extend(
             [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
         )
-        if self.tokenizer.no_captions is not None:
-            # no-captions probability is collected separately
-            suppress_tokens.append(self.tokenizer.no_captions)
+        if self.tokenizer.no_speech is not None:
+            # no-speech probability is collected separately
+            suppress_tokens.append(self.tokenizer.no_speech)

         return tuple(sorted(set(suppress_tokens)))

@@ -580,15 +580,15 @@ class DecodingTask:
         assert audio_features.shape[0] == tokens.shape[0]
         n_batch = tokens.shape[0]
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
-        no_caption_probs = [np.nan] * n_batch
+        no_speech_probs = [np.nan] * n_batch

         try:
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)

-                if i == 0 and self.tokenizer.no_captions is not None:  # save no_caption_probs
+                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
-                    no_caption_probs = probs_at_sot[:, self.tokenizer.no_captions].tolist()
+                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                 # now we need to consider the logits at the last token only
                 logits = logits[:, -1]
@@ -605,7 +605,7 @@ class DecodingTask:
         finally:
             self.inference.cleanup_caching()

-        return tokens, sum_logprobs, no_caption_probs
+        return tokens, sum_logprobs, no_speech_probs

     @torch.no_grad()
     def run(self, mel: Tensor) -> List[DecodingResult]:
@@ -629,12 +629,12 @@ class DecodingTask:
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

         # call the main sampling loop
-        tokens, sum_logprobs, no_caption_probs = self._main_loop(audio_features, tokens)
+        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
-        no_caption_probs = no_caption_probs[:: self.n_group]
-        assert audio_features.shape[0] == len(no_caption_probs) == n_audio
+        no_speech_probs = no_speech_probs[:: self.n_group]
+        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

         tokens = tokens.reshape(n_audio, self.n_group, -1)
         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
@@ -653,7 +653,7 @@ class DecodingTask:
         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
         avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]

-        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_caption_probs)
+        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

@@ -664,11 +664,11 @@ class DecodingTask:
                 tokens=tokens,
                 text=text,
                 avg_logprob=avg_logprob,
-                no_caption_prob=no_caption_prob,
+                no_speech_prob=no_speech_prob,
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
             )
-            for text, language, tokens, features, avg_logprob, no_caption_prob in zip(*fields)
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
         ]