@@ -11,6 +11,7 @@ from .audio import (
     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
+    N_SAMPLES,
     SAMPLE_RATE,
     log_mel_spectrogram,
     pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
-    mel = log_mel_spectrogram(audio)
+    # Pad 30 seconds of silence to the input audio, for slicing
+    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    content_frames = mel.shape[-1] - N_FRAMES
 
     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
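
With this change the mel spectrogram always carries an extra 30 seconds of silence at the end, and `content_frames` counts only the frames that come from the real audio. A rough sketch of the frame arithmetic, using Whisper's audio constants (the 45-second clip length is a made-up example):

```python
# Sketch of the frame bookkeeping introduced above (not part of the patch).
SAMPLE_RATE = 16000                      # samples per second
HOP_LENGTH = 160                         # samples per mel frame
N_SAMPLES = 30 * SAMPLE_RATE             # 480000 samples = one 30 s window
N_FRAMES = N_SAMPLES // HOP_LENGTH       # 3000 mel frames = one 30 s window

audio_seconds = 45                       # hypothetical clip length
audio_samples = audio_seconds * SAMPLE_RATE
# log_mel_spectrogram(audio, padding=N_SAMPLES) yields roughly this many frames:
total_frames = (audio_samples + N_SAMPLES) // HOP_LENGTH
content_frames = total_frames - N_FRAMES # frames backed by real audio
print(total_frames, content_frames)      # 7500, 4500  (4500 frames ~= 45 s)
```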
@@ -212,14 +215,13 @@ def transcribe(
         }
 
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
-    num_frames = mel.shape[-1]
     with tqdm.tqdm(
-        total=num_frames, unit="frames", disable=verbose is not False
+        total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
-        while seek < num_frames:
+        while seek < content_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[:, seek:]
-            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            mel_segment = mel[:, seek : seek + N_FRAMES]
+            segment_size = min(N_FRAMES, content_frames - seek)
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
 
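
The slicing change above is what the padding enables: each iteration can take a full `N_FRAMES` window from the padded mel, while `segment_size` is measured against `content_frames` so the trailing silence does not count toward `segment_duration`. A toy illustration (the shape and seek value are hypothetical):

```python
# Illustration of the new window slicing (not part of the patch).
import torch

SAMPLE_RATE, HOP_LENGTH, N_FRAMES = 16000, 160, 3000
content_frames = 4500                              # e.g. 45 s of real audio
mel = torch.zeros(80, content_frames + N_FRAMES)   # mel padded with 30 s of silence

seek = 3000                                        # second pass through the loop
mel_segment = mel[:, seek : seek + N_FRAMES]       # full-width slice thanks to the padding
segment_size = min(N_FRAMES, content_frames - seek)            # 1500, not 3000
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE     # 15.0 s of real speech
print(mel_segment.shape[-1], segment_size, segment_duration)   # 3000 1500 15.0
```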
@@ -246,20 +248,18 @@ def transcribe(
             current_tokens = []
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
-                0
-            ].add_(1)
-            if (
-                len(consecutive) > 0
-            ):  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
-                    False,
-                    True,
-                ]:
-                    consecutive = consecutive.tolist() + [len(tokens)]
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+            consecutive.add_(1)
+            if len(consecutive) > 0:
+                # if the output contains two consecutive timestamp tokens
+                slices = consecutive.tolist()
+                if single_timestamp_ending:
+                    slices.append(len(tokens))
 
                 last_slice = 0
-                for current_slice in consecutive:
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = (
                         sliced_tokens[0].item() - tokenizer.timestamp_begin
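
The refactor replaces the walrus-assigned `ended_with_single_timestamp` with an up-front `single_timestamp_ending` flag and collects the slice boundaries into a plain `slices` list. A toy run of the same logic on a made-up token sequence (the `timestamp_begin` value of 100 is fictitious; Whisper's real tokenizer uses a much larger id):

```python
# Toy walk-through of the timestamp slicing above (illustration only).
import torch

timestamp_begin = 100                         # pretend ids >= 100 are timestamp tokens
tokens = torch.tensor([100, 7, 8, 105, 105, 9, 110])

timestamp_tokens = tokens.ge(timestamp_begin)                  # [T, F, F, T, T, F, T]
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]   # True

consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
consecutive.add_(1)                           # second token of each adjacent timestamp pair
slices = consecutive.tolist()                 # [4]
if single_timestamp_ending:
    slices.append(len(tokens))                # [4, 7]: the tail becomes one more segment

last_slice = 0
for current_slice in slices:
    print(tokens[last_slice:current_slice].tolist())   # [100, 7, 8, 105] then [105, 9, 110]
    last_slice = current_slice
```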
@@ -278,7 +278,7 @@ def transcribe(
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
 
-                if ended_with_single_timestamp:
+                if single_timestamp_ending:
                     # single timestamp at the end means no speech after the last timestamp.
                     seek += segment_size
                 else:
@@ -329,7 +329,7 @@ def transcribe(
                 word_end_timestamps = [
                     w["end"] for s in current_segments for w in s["words"]
                 ]
-                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                     seek_shift = round(
                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                     )
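
With word timestamps, `seek` now advances to the end of the last aligned word only when the window did not end on a lone timestamp; the shift is converted from seconds to mel frames via `FRAMES_PER_SECOND`. Rough arithmetic with made-up timings:

```python
# Hypothetical numbers for the seek adjustment above (not part of the patch).
FRAMES_PER_SECOND = 100                # SAMPLE_RATE // HOP_LENGTH = 16000 // 160

time_offset = 30.0                     # window starts at 30 s (seek = 3000 frames)
last_word_end = 52.48                  # end time of the last aligned word (made up)
seek_shift = round((last_word_end - time_offset) * FRAMES_PER_SECOND)
print(seek_shift)                      # 2248 frames, i.e. advance 22.48 s, not a full window
```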
@@ -356,7 +356,7 @@ def transcribe(
             )
 
             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek)
+            pbar.update(min(content_frames, seek) - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),