فهرست منبع

Improve timestamp heuristics. (#1461)

* Improve timestamp heuristics.

* Track pauses with last_speech_timestamp
ryanheise 1 سال پیش
والد
کامیت
f572f2161b
2 فایل تغییر یافته به همراه 60 افزوده شده و 28 حذف شده
  1. 56 28
      whisper/timing.py
  2. 4 0
      whisper/transcribe.py

+ 56 - 28
whisper/timing.py

@@ -225,28 +225,6 @@ def find_alignment(
         for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]
 
-    # hack: truncate long words at the start of a window and the start of a sentence.
-    # a better segmentation algorithm based on VAD should be able to replace this.
-    word_durations = end_times - start_times
-    word_durations = word_durations[word_durations.nonzero()]
-    if len(word_durations) > 0:
-        median_duration = np.median(word_durations)
-        max_duration = median_duration * 2
-        sentence_end_marks = ".。!!??"
-        # ensure words at sentence boundaries are not longer than twice the median word duration.
-        for i in range(1, len(start_times)):
-            if end_times[i] - start_times[i] > max_duration:
-                if words[i] in sentence_end_marks:
-                    end_times[i] = start_times[i] + max_duration
-                elif words[i - 1] in sentence_end_marks:
-                    start_times[i] = end_times[i] - max_duration
-        # ensure the first and second word is not longer than twice the median word duration.
-        if len(start_times) > 0 and end_times[0] - start_times[0] > max_duration:
-            if len(start_times) > 1 and end_times[1] - start_times[1] > max_duration:
-                boundary = max(end_times[1] / 2, end_times[1] - max_duration)
-                end_times[0] = start_times[1] = boundary
-            start_times[0] = max(0, end_times[0] - max_duration)
-
     return [
         WordTiming(word, tokens, start, end, probability)
         for word, tokens, start, end, probability in zip(
@@ -298,6 +276,7 @@ def add_word_timestamps(
     num_frames: int,
     prepend_punctuations: str = "\"'“¿([{-",
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
+    last_speech_timestamp: float,
     **kwargs,
 ):
     if len(segments) == 0:
@@ -310,6 +289,25 @@ def add_word_timestamps(
 
     text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
+    word_durations = np.array([t.end - t.start for t in alignment])
+    word_durations = word_durations[word_durations.nonzero()]
+    median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+    max_duration = median_duration * 2
+
+    # hack: truncate long words at sentence boundaries.
+    # a better segmentation algorithm based on VAD should be able to replace this.
+    if len(word_durations) > 0:
+        median_duration = np.median(word_durations)
+        max_duration = median_duration * 2
+        sentence_end_marks = ".。!!??"
+        # ensure words at sentence boundaries are not longer than twice the median word duration.
+        for i in range(1, len(alignment)):
+            if alignment[i].end - alignment[i].start > max_duration:
+                if alignment[i].word in sentence_end_marks:
+                    alignment[i].end = alignment[i].start + max_duration
+                elif alignment[i - 1].word in sentence_end_marks:
+                    alignment[i].start = alignment[i].end - max_duration
+
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
@@ -335,18 +333,48 @@ def add_word_timestamps(
             saved_tokens += len(timing.tokens)
             word_index += 1
 
+        # hack: truncate long words at segment boundaries.
+        # a better segmentation algorithm based on VAD should be able to replace this.
         if len(words) > 0:
-            segment["start"] = words[0]["start"]
-            # hack: prefer the segment-level end timestamp if the last word is too long.
-            # a better segmentation algorithm based on VAD should be able to replace this.
+            # ensure the first and second word after a pause is not longer than
+            # twice the median word duration.
+            if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
+                words[0]["end"] - words[0]["start"] > max_duration
+                or (
+                    len(words) > 1
+                    and words[1]["end"] - words[0]["start"] > max_duration * 2
+                )
+            ):
+                if (
+                    len(words) > 1
+                    and words[1]["end"] - words[1]["start"] > max_duration
+                ):
+                    boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
+                    words[0]["end"] = words[1]["start"] = boundary
+                words[0]["start"] = max(0, words[0]["end"] - max_duration)
+
+            # prefer the segment-level start timestamp if the first word is too long.
+            if (
+                segment["start"] < words[0]["end"]
+                and segment["start"] - 0.5 > words[0]["start"]
+            ):
+                words[0]["start"] = max(
+                    0, min(words[0]["end"] - median_duration, segment["start"])
+                )
+            else:
+                segment["start"] = words[0]["start"]
+
+            # prefer the segment-level end timestamp if the last word is too long.
             if (
                 segment["end"] > words[-1]["start"]
                 and segment["end"] + 0.5 < words[-1]["end"]
             ):
-                # adjust the word-level timestamps based on the segment-level timestamps
-                words[-1]["end"] = segment["end"]
+                words[-1]["end"] = max(
+                    words[-1]["start"] + median_duration, segment["end"]
+                )
             else:
-                # adjust the segment-level timestamps based on the word-level timestamps
                 segment["end"] = words[-1]["end"]
 
+            last_speech_timestamp = segment["end"]
+
         segment["words"] = words

+ 4 - 0
whisper/transcribe.py

@@ -222,6 +222,7 @@ def transcribe(
     with tqdm.tqdm(
         total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
+        last_speech_timestamp = 0.0
         while seek < content_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
             mel_segment = mel[:, seek : seek + N_FRAMES]
@@ -321,10 +322,13 @@ def transcribe(
                     num_frames=segment_size,
                     prepend_punctuations=prepend_punctuations,
                     append_punctuations=append_punctuations,
+                    last_speech_timestamp=last_speech_timestamp,
                 )
                 word_end_timestamps = [
                     w["end"] for s in current_segments for w in s["words"]
                 ]
+                if len(word_end_timestamps) > 0:
+                    last_speech_timestamp = word_end_timestamps[-1]
                 if not single_timestamp_ending and len(word_end_timestamps) > 0:
                     seek_shift = round(
                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND