|
@@ -92,41 +92,37 @@ def transcribe(
|
|
|
if verbose is not None:
|
|
|
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
|
|
|
|
|
|
- mel = mel.unsqueeze(0)
|
|
|
language = decode_options["language"]
|
|
|
task = decode_options.get("task", "transcribe")
|
|
|
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
|
|
|
|
|
- def decode_with_fallback(segment: torch.Tensor) -> List[DecodingResult]:
|
|
|
+ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
|
|
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
|
|
|
- kwargs = {**decode_options}
|
|
|
- t = temperatures[0]
|
|
|
- if t == 0:
|
|
|
- best_of = kwargs.pop("best_of", None)
|
|
|
- else:
|
|
|
- best_of = kwargs.get("best_of", None)
|
|
|
-
|
|
|
- options = DecodingOptions(**kwargs, temperature=t)
|
|
|
- results = model.decode(segment, options)
|
|
|
-
|
|
|
- kwargs.pop("beam_size", None) # no beam search for t > 0
|
|
|
- kwargs.pop("patience", None) # no patience for t > 0
|
|
|
- kwargs["best_of"] = best_of # enable best_of for t > 0
|
|
|
- for t in temperatures[1:]:
|
|
|
- needs_fallback = [
|
|
|
- compression_ratio_threshold is not None
|
|
|
- and result.compression_ratio > compression_ratio_threshold
|
|
|
- or logprob_threshold is not None
|
|
|
- and result.avg_logprob < logprob_threshold
|
|
|
- for result in results
|
|
|
- ]
|
|
|
- if any(needs_fallback):
|
|
|
- options = DecodingOptions(**kwargs, temperature=t)
|
|
|
- retries = model.decode(segment[needs_fallback], options)
|
|
|
- for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
|
|
|
- results[original_index] = retries[retry_index]
|
|
|
-
|
|
|
- return results
|
|
|
+ decode_result = None
|
|
|
+
|
|
|
+ for t in temperatures:
|
|
|
+ kwargs = {**decode_options}
|
|
|
+ if t > 0:
|
|
|
+ # disable beam_size and patience when t > 0
|
|
|
+ kwargs.pop("beam_size", None)
|
|
|
+ kwargs.pop("patience", None)
|
|
|
+ else:
|
|
|
+ # disable best_of when t == 0
|
|
|
+ kwargs.pop("best_of", None)
|
|
|
+
|
|
|
+ options = DecodingOptions(**kwargs, temperature=t)
|
|
|
+ decode_result = model.decode(segment, options)
|
|
|
+
|
|
|
+ needs_fallback = False
|
|
|
+ if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
|
|
|
+ needs_fallback = True # too repetitive
|
|
|
+ if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
|
|
|
+ needs_fallback = True # average log probability is too low
|
|
|
+
|
|
|
+ if not needs_fallback:
|
|
|
+ break
|
|
|
+
|
|
|
+ return decode_result
|
|
|
|
|
|
seek = 0
|
|
|
input_stride = exact_div(
|
|
@@ -175,11 +171,11 @@ def transcribe(
|
|
|
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
|
|
while seek < num_frames:
|
|
|
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
|
|
- segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
|
|
|
+ segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
|
|
|
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
|
|
|
|
|
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
|
- result = decode_with_fallback(segment)[0]
|
|
|
+ result: DecodingResult = decode_with_fallback(segment)
|
|
|
tokens = torch.tensor(result.tokens)
|
|
|
|
|
|
if no_speech_threshold is not None:
|