@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union, TYPE_CHECKING
 
 import numpy as np
 import torch
+import tqdm
 
 from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
 from .decoding import DecodingOptions, DecodingResult
@@ -87,7 +88,7 @@ def transcribe(
         segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
         _, probs = model.detect_language(segment)
         decode_options["language"] = max(probs, key=probs.get)
-        print(f"Detected language: {LANGUAGES[decode_options['language']]}")
+        print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
     mel = mel.unsqueeze(0)
     language = decode_options["language"]
@@ -160,72 +161,81 @@ def transcribe(
         if verbose:
             print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
 
-    while seek < mel.shape[-1]:
-        timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-        segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
-        segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
-
-        decode_options["prompt"] = all_tokens[prompt_reset_since:]
-        result = decode_with_fallback(segment)[0]
-        tokens = torch.tensor(result.tokens)
-
-        if no_speech_threshold is not None:
-            # no voice activity check
-            should_skip = result.no_speech_prob > no_speech_threshold
-            if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
-                # don't skip if the logprob is high enough, despite the no_speech_prob
-                should_skip = False
-
-            if should_skip:
-                seek += segment.shape[-1]  # fast-forward to the next segment boundary
-                continue
-
-        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
-        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
-            last_slice = 0
-            for current_slice in consecutive:
-                sliced_tokens = tokens[last_slice:current_slice]
-                start_timestamp_position = (
-                    sliced_tokens[0].item() - tokenizer.timestamp_begin
-                )
-                end_timestamp_position = (
-                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
+    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
+    num_frames = mel.shape[-1]
+    previous_seek_value = seek
+
+    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose) as pbar:
+        while seek < num_frames:
+            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
+            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+
+            decode_options["prompt"] = all_tokens[prompt_reset_since:]
+            result = decode_with_fallback(segment)[0]
+            tokens = torch.tensor(result.tokens)
+
+            if no_speech_threshold is not None:
+                # no voice activity check
+                should_skip = result.no_speech_prob > no_speech_threshold
+                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                    # don't skip if the logprob is high enough, despite the no_speech_prob
+                    should_skip = False
+
+                if should_skip:
+                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
+                    continue
+
+            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
+            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+                last_slice = 0
+                for current_slice in consecutive:
+                    sliced_tokens = tokens[last_slice:current_slice]
+                    start_timestamp_position = (
+                        sliced_tokens[0].item() - tokenizer.timestamp_begin
+                    )
+                    end_timestamp_position = (
+                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                    )
+                    add_segment(
+                        start=timestamp_offset + start_timestamp_position * time_precision,
+                        end=timestamp_offset + end_timestamp_position * time_precision,
+                        text_tokens=sliced_tokens[1:-1],
+                        result=result,
+                    )
+                    last_slice = current_slice
+                last_timestamp_position = (
+                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                 )
+                seek += last_timestamp_position * input_stride
+                all_tokens.extend(tokens[: last_slice + 1].tolist())
+            else:
+                duration = segment_duration
+                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+                if len(timestamps) > 0:
+                    # no consecutive timestamps but it has a timestamp; use the last one.
+                    # single timestamp at the end means no speech after the last timestamp.
+                    last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
+                    duration = last_timestamp_position * time_precision
+
                 add_segment(
-                    start=timestamp_offset + start_timestamp_position * time_precision,
-                    end=timestamp_offset + end_timestamp_position * time_precision,
-                    text_tokens=sliced_tokens[1:-1],
+                    start=timestamp_offset,
+                    end=timestamp_offset + duration,
+                    text_tokens=tokens,
                     result=result,
                 )
-                last_slice = current_slice
-            last_timestamp_position = (
-                tokens[last_slice - 1].item() - tokenizer.timestamp_begin
-            )
-            seek += last_timestamp_position * input_stride
-            all_tokens.extend(tokens[: last_slice + 1].tolist())
-        else:
-            duration = segment_duration
-            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-            if len(timestamps) > 0:
-                # no consecutive timestamps but it has a timestamp; use the last one.
-                # single timestamp at the end means no speech after the last timestamp.
-                last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
-                duration = last_timestamp_position * time_precision
-
-            add_segment(
-                start=timestamp_offset,
-                end=timestamp_offset + duration,
-                text_tokens=tokens,
-                result=result,
-            )
-
-            seek += segment.shape[-1]
-            all_tokens.extend(tokens.tolist())
-
-        if not condition_on_previous_text or result.temperature > 0.5:
-            # do not feed the prompt tokens if a high temperature was used
-            prompt_reset_since = len(all_tokens)
+
+                seek += segment.shape[-1]
+                all_tokens.extend(tokens.tolist())
+
+            if not condition_on_previous_text or result.temperature > 0.5:
+                # do not feed the prompt tokens if a high temperature was used
+                prompt_reset_since = len(all_tokens)
+
+            # update progress bar
+            pbar.update(min(num_frames, seek) - previous_seek_value)
+            previous_seek_value = seek
 
     return dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
 
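
Note (not part of the patch): the progress accounting above relies on `seek` advancing by a variable number of mel frames per iteration, and on the final advance possibly overshooting `num_frames`; clamping with `min(num_frames, seek)` keeps the cumulative updates equal to the bar's `total`. Below is a minimal, self-contained Python sketch of that pattern, using made-up frame counts (the real values come from the decoding loop):

    import tqdm

    num_frames = 9500                # pretend total number of mel frames
    hops = [3000, 3000, 3000, 3000]  # per-iteration seek advances; the last one overshoots

    seek = 0
    previous_seek_value = 0
    with tqdm.tqdm(total=num_frames, unit='frames') as pbar:
        for hop in hops:
            if seek >= num_frames:
                break
            seek += hop  # stands in for segment.shape[-1] or the timestamp-based stride
            # clamp so the final update cannot push the bar past its total
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

Since the bar is created with `disable=verbose`, it is suppressed exactly when per-segment text is already being printed, so the two kinds of progress output never interleave.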