Jelajahi Sumber

Fix infinite loop caused by incorrect timestamp tokens prediction (#914)

* Fix infinite loop caused by incorrect timestamp tokens prediction

https://github.com/openai/whisper/discussions/810

* Update decoding.py

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Andrey Chernykh 2 tahun lalu
induk
melakukan
7858aa9c08
1 mengubah file dengan 7 tambahan dan 1 penghapusan
  1. 7 1
      whisper/decoding.py

+ 7 - 1
whisper/decoding.py

@@ -412,7 +412,8 @@ class ApplyTimestampRules(LogitFilter):
 
         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         for k in range(tokens.shape[0]):
-            seq = [t for t in tokens[k, self.sample_begin :].tolist()]
+            sampled_tokens = tokens[k, self.sample_begin :]
+            seq = [t for t in sampled_tokens.tolist()]
             last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
             penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
 
@@ -422,6 +423,11 @@ class ApplyTimestampRules(LogitFilter):
                 else:  # cannot be normal text tokens
                     logits[k, : self.tokenizer.eot] = -np.inf
 
+            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
+            if timestamps.numel() > 0:
+                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
+                logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
+
         if tokens.shape[1] == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
             logits[:, : self.tokenizer.timestamp_begin] = -np.inf