Bläddra i källkod

Decoding improvements (#1033)

* suppress task tokens (transcribe/translate)

* not ignoring the last segment ending with one timestamp
Jong Wook Kim 2 år sedan
förälder
incheckning
eab8d920ed
3 ändrade filer med 30 tillägg och 16 borttagningar
  1. 7 1
      whisper/decoding.py
  2. 8 0
      whisper/tokenizer.py
  3. 15 15
      whisper/transcribe.py

+ 7 - 1
whisper/decoding.py

@@ -549,7 +549,13 @@ class DecodingTask:
             assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
 
         suppress_tokens.extend(
-            [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
+            [
+                self.tokenizer.transcribe,
+                self.tokenizer.translate,
+                self.tokenizer.sot,
+                self.tokenizer.sot_prev,
+                self.tokenizer.sot_lm
+            ]
         )
         if self.tokenizer.no_speech is not None:
             # no-speech probability is collected separately

+ 8 - 0
whisper/tokenizer.py

@@ -161,6 +161,14 @@ class Tokenizer:
         return self.tokenizer.eos_token_id
 
     @cached_property
+    def transcribe(self) -> int:
+        return self._get_single_token_id("<|transcribe|>")
+
+    @cached_property
+    def translate(self) -> int:
+        return self._get_single_token_id("<|translate|>")
+
+    @cached_property
     def sot(self) -> int:
         return self._get_single_token_id("<|startoftranscript|>")
 

+ 15 - 15
whisper/transcribe.py

@@ -197,35 +197,35 @@ def transcribe(
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
             if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
+                    consecutive = consecutive.tolist() + [len(tokens)]
                 last_slice = 0
                 for current_slice in consecutive:
                     sliced_tokens = tokens[last_slice:current_slice]
-                    start_timestamp_position = (
-                        sliced_tokens[0].item() - tokenizer.timestamp_begin
-                    )
-                    end_timestamp_position = (
-                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                    )
+                    start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
+                    end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
                     add_segment(
-                        start=timestamp_offset + start_timestamp_position * time_precision,
-                        end=timestamp_offset + end_timestamp_position * time_precision,
+                        start=timestamp_offset + start_timestamp_pos * time_precision,
+                        end=timestamp_offset + end_timestamp_pos * time_precision,
                         text_tokens=sliced_tokens[1:-1],
                         result=result,
                     )
                     last_slice = current_slice
-                last_timestamp_position = (
-                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
-                )
-                seek += last_timestamp_position * input_stride
+                if ended_with_single_timestamp:
+                    # single timestamp at the end means no speech after the last timestamp.
+                    seek += segment.shape[-1]
+                else:
+                    # otherwise, ignore the unfinished segment and seek to the last timestamp
+                    last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                    seek += last_timestamp_pos * input_stride
                 all_tokens.extend(tokens[: last_slice + 1].tolist())
             else:
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                 if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
                     # no consecutive timestamps but it has a timestamp; use the last one.
-                    # single timestamp at the end means no speech after the last timestamp.
-                    last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
-                    duration = last_timestamp_position * time_precision
+                    last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
+                    duration = last_timestamp_pos * time_precision
 
                 add_segment(
                     start=timestamp_offset,