
fix all_tokens handling that caused more repetitions and discrepancy in JSON (#1060)
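
As the diff below shows, the consecutive-timestamps path appended the decoded tokens to `all_tokens` twice: once directly via `all_tokens.extend(tokens[: last_slice + 1].tolist())` and once through the `current_tokens` bookkeeping list. The duplicated tokens inflated the prompt fed to subsequent windows (encouraging more repetitions), and since the top-level `result["text"]` is decoded from `all_tokens`, it drifted out of sync with the segment texts in the JSON output. This commit drops `current_tokens` entirely, stores each segment's full token sequence (timestamp tokens included) on the segment itself, rebuilds `all_tokens` only from the segments that survive the hallucination filtering, and assigns segment ids at append time.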

Jong Wook Kim · 2 years ago · commit 38f2f4d99d
3 changed files with 14 additions and 11 deletions:

  1. tests/test_transcribe.py (+1, -0)
  2. whisper/timing.py (+1, -1)
  3. whisper/transcribe.py (+12, -10)

tests/test_transcribe.py (+1, -0)

@@ -17,6 +17,7 @@ def test_transcribe(model_name: str):
         audio_path, language=language, temperature=0.0, word_timestamps=True
     )
     assert result["language"] == "en"
+    assert result["text"] == "".join([s["text"] for s in result["segments"]])
 
     transcription = result["text"].lower()
     assert "my fellow americans" in transcription

whisper/timing.py (+1, -1)

@@ -290,7 +290,7 @@ def add_word_timestamps(
     if len(segments) == 0:
         return
 
-    text_tokens = [t for segment in segments for t in segment["tokens"]]
+    text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
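
Segments now carry their raw token sequences, timestamp tokens included, so the alignment input has to drop everything at or above `tokenizer.eot`: in Whisper's vocabulary, every ID from `eot` upward is a special token (end-of-text, task/language tags, timestamps). A standalone sketch of that split, using a hypothetical helper name:

```python
from typing import List, Tuple

def split_tokens(tokens: List[int], eot: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper (not part of whisper): separate a segment's plain
    text tokens from special/timestamp tokens. IDs below `eot` are ordinary
    text; IDs at or above it are special."""
    text = [t for t in tokens if t < eot]
    special = [t for t in tokens if t >= eot]
    return text, special
```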
 

whisper/transcribe.py (+12, -10)

@@ -200,14 +200,14 @@ def transcribe(
     def new_segment(
         *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
     ):
-        text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
+        tokens = tokens.tolist()
+        text_tokens = [token for token in tokens if token < tokenizer.eot]
         return {
-            "id": len(all_segments),
             "seek": seek,
             "start": start,
             "end": end,
             "text": tokenizer.decode(text_tokens),
-            "tokens": text_tokens,
+            "tokens": tokens,
             "temperature": result.temperature,
             "avg_logprob": result.avg_logprob,
             "compression_ratio": result.compression_ratio,
@@ -245,7 +245,6 @@ def transcribe(
 
             previous_seek = seek
             current_segments = []
-            current_tokens = []
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -275,7 +274,6 @@ def transcribe(
                             result=result,
                         )
                     )
-                    current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
 
                 if single_timestamp_ending:
@@ -287,7 +285,6 @@ def transcribe(
                         tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                     )
                     seek += last_timestamp_pos * input_stride
-                all_tokens.extend(tokens[: last_slice + 1].tolist())
             else:
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@@ -309,7 +306,6 @@ def transcribe(
                         result=result,
                     )
                 )
-                current_tokens.append(tokens.tolist())
                 seek += segment_size
 
             if not condition_on_previous_text or result.temperature > 0.5:
@@ -348,11 +344,17 @@ def transcribe(
                     segment["text"] = ""
                     segment["tokens"] = []
                     segment["words"] = []
-                    current_tokens[i] = []
 
-            all_segments.extend(current_segments)
+            all_segments.extend(
+                [
+                    {"id": i, **segment}
+                    for i, segment in enumerate(
+                        current_segments, start=len(all_segments)
+                    )
+                ]
+            )
             all_tokens.extend(
-                [token for segment in current_tokens for token in segment]
+                [token for segment in current_segments for token in segment["tokens"]]
             )
 
             # update progress bar
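
Two things change here. First, `all_tokens` is now rebuilt from `current_segments` after the filtering loop above, so a suppressed segment contributes nothing to the prompt for the next window and nothing extra to `result["text"]`; the removed `current_tokens` list, together with the direct `all_tokens.extend` in the consecutive-timestamps branch, had been double-counting those tokens. Second, segment ids are assigned only at append time via `enumerate(..., start=len(all_segments))`, which keeps them contiguous across decoding windows even when segments are filtered out. A standalone sketch of that id pattern (illustrative data, not whisper output):

```python
all_segments = [{"id": 0, "text": " Hello."}]  # collected so far
current_segments = [{"text": " My fellow"}, {"text": " Americans."}]

# Continue numbering from what has already been committed
all_segments.extend(
    {"id": i, **segment}
    for i, segment in enumerate(current_segments, start=len(all_segments))
)
assert [s["id"] for s in all_segments] == [0, 1, 2]
```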