Ver código fonte

Fix alignment between the segments and the list of words (#1087)

* Fix alignment between the segments and the list of words

* Ensure the word index does not overflow
Guillaume Klein 1 ano atrás
pai
commit
671ac5a4ce
1 arquivos alterados com 30 adições e 22 exclusões
  1. 30 22
      whisper/timing.py

+ 30 - 22
whisper/timing.py

@@ -1,3 +1,4 @@
+import itertools
 import subprocess
 import warnings
 from dataclasses import dataclass
@@ -290,34 +291,41 @@ def add_word_timestamps(
     if len(segments) == 0:
         return
 
-    text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
+    text_tokens_per_segment = [
+        [token for token in segment["tokens"] if token < tokenizer.eot]
+        for segment in segments
+    ]
+
+    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    segment_lengths = [len(s["tokens"]) for s in segments]
-    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
-
-    for segment in segments:
-        segment["words"] = []
-
-    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
-    for i, timing in enumerate(alignment):
-        if timing.word:
-            segment = segments[token_sources[word_boundaries[i]]]
-            start = round(time_offset + timing.start, 2)
-            end = round(time_offset + timing.end, 2)
-            segment["words"].append(
-                dict(
-                    word=timing.word,
-                    start=start,
-                    end=end,
-                    probability=timing.probability,
+    word_index = 0
+
+    for segment, text_tokens in zip(segments, text_tokens_per_segment):
+        saved_tokens = 0
+        words = []
+
+        while word_index < len(alignment) and saved_tokens < len(text_tokens):
+            timing = alignment[word_index]
+
+            if timing.word:
+                words.append(
+                    dict(
+                        word=timing.word,
+                        start=round(time_offset + timing.start, 2),
+                        end=round(time_offset + timing.end, 2),
+                        probability=timing.probability,
+                    )
                 )
-            )
 
-    for segment in segments:
-        if len(words := segment["words"]) > 0:
+            saved_tokens += len(timing.tokens)
+            word_index += 1
+
+        if len(words) > 0:
             # adjust the segment-level timestamps based on the word-level timestamps
             segment["start"] = words[0]["start"]
             segment["end"] = words[-1]["end"]
+
+        segment["words"] = words