Browse Source

attempt to fix the repetition/hallucination issue identified in #1046 (#1052)

* attempt to fix the repetition/hallucination issue identified in #1046

* zero-pad the audio instead of spectrogram

* formatting fix

* delete debug print
Jong Wook Kim 2 years ago
parent
commit
919a713499
2 changed files with 38 additions and 27 deletions
  1. 17 6
      whisper/audio.py
  2. 21 21
      whisper/transcribe.py

+ 17 - 6
whisper/audio.py

@@ -1,6 +1,6 @@
 import os
 from functools import lru_cache
-from typing import Union
+from typing import Optional, Union
 
 import ffmpeg
 import numpy as np
@@ -15,10 +15,8 @@ N_FFT = 400
 N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
-N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
-N_FRAMES = exact_div(
-    N_SAMPLES, HOP_LENGTH
-)  # 3000: number of frames in a mel spectrogram input
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input
 
 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
@@ -100,7 +98,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
 
 
 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
+    audio: Union[str, np.ndarray, torch.Tensor],
+    n_mels: int = N_MELS,
+    padding: int = 0,
+    device: Optional[Union[str, torch.device]] = None,
 ):
     """
     Compute the log-Mel spectrogram of
@@ -113,6 +114,12 @@ def log_mel_spectrogram(
     n_mels: int
         The number of Mel-frequency filters, only 80 is supported
 
+    padding: int
+        Number of zero samples to pad to the right
+
+    device: Optional[Union[str, torch.device]]
+        If given, the audio tensor is moved to this device before STFT
+
     Returns
     -------
     torch.Tensor, shape = (80, n_frames)
@@ -123,6 +130,10 @@ def log_mel_spectrogram(
             audio = load_audio(audio)
         audio = torch.from_numpy(audio)
 
+    if device is not None:
+        audio = audio.to(device)
+    if padding > 0:
+        audio = F.pad(audio, (0, padding))
     window = torch.hann_window(N_FFT).to(audio.device)
     stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
     magnitudes = stft[..., :-1].abs() ** 2

+ 21 - 21
whisper/transcribe.py

@@ -11,6 +11,7 @@ from .audio import (
     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
+    N_SAMPLES,
     SAMPLE_RATE,
     log_mel_spectrogram,
     pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
-    mel = log_mel_spectrogram(audio)
+    # Pad 30-seconds of silence to the input audio, for slicing
+    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    content_frames = mel.shape[-1] - N_FRAMES
 
     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
@@ -212,14 +215,13 @@ def transcribe(
         }
 
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
-    num_frames = mel.shape[-1]
     with tqdm.tqdm(
-        total=num_frames, unit="frames", disable=verbose is not False
+        total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
-        while seek < num_frames:
+        while seek < content_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[:, seek:]
-            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            mel_segment = mel[:, seek : seek + N_FRAMES]
+            segment_size = min(N_FRAMES, content_frames - seek)
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
 
@@ -246,20 +248,18 @@ def transcribe(
             current_tokens = []
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
-                0
-            ].add_(1)
-            if (
-                len(consecutive) > 0
-            ):  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
-                    False,
-                    True,
-                ]:
-                    consecutive = consecutive.tolist() + [len(tokens)]
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+            consecutive.add_(1)
+            if len(consecutive) > 0:
+                # if the output contains two consecutive timestamp tokens
+                slices = consecutive.tolist()
+                if single_timestamp_ending:
+                    slices.append(len(tokens))
 
                 last_slice = 0
-                for current_slice in consecutive:
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = (
                         sliced_tokens[0].item() - tokenizer.timestamp_begin
@@ -278,7 +278,7 @@ def transcribe(
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
 
-                if ended_with_single_timestamp:
+                if single_timestamp_ending:
                     # single timestamp at the end means no speech after the last timestamp.
                     seek += segment_size
                 else:
@@ -329,7 +329,7 @@ def transcribe(
                 word_end_timestamps = [
                     w["end"] for s in current_segments for w in s["words"]
                 ]
-                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                     seek_shift = round(
                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                     )
@@ -356,7 +356,7 @@ def transcribe(
             )
 
             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek)
+            pbar.update(min(content_frames, seek) - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),