浏览代码

Squash long words at window and sentence boundaries. (#1114)

* Squash long words at window and sentence boundaries.

* Formatting requirements.

* Fix squashing logic to point to correct words.

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
ryanheise 2 年之前
父节点
当前提交
255887f219
共有 1 个文件被更改,包括 25 次插入7 次删除
  1. 25 7
      whisper/timing.py

+ 25 - 7
whisper/timing.py

@@ -225,17 +225,26 @@ def find_alignment(
         for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]
 
-    # hack: ensure the first and second word is not longer than twice the median word duration.
+    # hack: truncate long words at the start of a window and the start of a sentence.
     # a better segmentation algorithm based on VAD should be able to replace this.
     word_durations = end_times - start_times
     word_durations = word_durations[word_durations.nonzero()]
     if len(word_durations) > 0:
         median_duration = np.median(word_durations)
         max_duration = median_duration * 2
-        if len(word_durations) >= 2 and word_durations[1] > max_duration:
-            boundary = max(end_times[2] / 2, end_times[2] - max_duration)
-            end_times[0] = start_times[1] = boundary
-        if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
+        sentence_end_marks = ".。!!??"
+        # ensure words at sentence boundaries are not longer than twice the median word duration.
+        for i in range(1, len(start_times)):
+            if end_times[i] - start_times[i] > max_duration:
+                if words[i] in sentence_end_marks:
+                    end_times[i] = start_times[i] + max_duration
+                elif words[i - 1] in sentence_end_marks:
+                    start_times[i] = end_times[i] - max_duration
+        # ensure the first and second word is not longer than twice the median word duration.
+        if len(start_times) > 0 and end_times[0] - start_times[0] > max_duration:
+            if len(start_times) > 1 and end_times[1] - start_times[1] > max_duration:
+                boundary = max(end_times[1] / 2, end_times[1] - max_duration)
+                end_times[0] = start_times[1] = boundary
             start_times[0] = max(0, end_times[0] - max_duration)
 
     return [
@@ -327,8 +336,17 @@ def add_word_timestamps(
             word_index += 1
 
         if len(words) > 0:
-            # adjust the segment-level timestamps based on the word-level timestamps
             segment["start"] = words[0]["start"]
-            segment["end"] = words[-1]["end"]
+            # hack: prefer the segment-level end timestamp if the last word is too long.
+            # a better segmentation algorithm based on VAD should be able to replace this.
+            if (
+                segment["end"] > words[-1]["start"]
+                and segment["end"] + 0.5 < words[-1]["end"]
+            ):
+                # adjust the word-level timestamps based on the segment-level timestamps
+                words[-1]["end"] = segment["end"]
+            else:
+                # adjust the segment-level timestamps based on the word-level timestamps
+                segment["end"] = words[-1]["end"]
 
         segment["words"] = words