Quellcode durchsuchen

suppress generating non-timestamp tokens at the beginning (#532)

jumon vor 2 Jahren
Ursprung
Commit
76148a56c5
1 geänderte Datei mit 8 neuen und 4 gelöschten Zeilen
  1. +8 -4
      whisper/decoding.py

+ 8 - 4
whisper/decoding.py

@@ -423,10 +423,14 @@ class ApplyTimestampRules(LogitFilter):
                 else:  # cannot be normal text tokens
                     logits[k, : self.tokenizer.eot] = -np.inf
 
-        # apply the `max_initial_timestamp` option
-        if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
-            last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
-            logits[:, last_allowed + 1 :] = -np.inf
+        if tokens.shape[1] == self.sample_begin:
+            # suppress generating non-timestamp tokens at the beginning
+            logits[:, : self.tokenizer.timestamp_begin] = -np.inf
+
+            # apply the `max_initial_timestamp` option
+            if self.max_initial_timestamp_index is not None:
+                last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                logits[:, last_allowed + 1 :] = -np.inf
 
         # if sum of probability over timestamps is above any other token, sample timestamp
         logprobs = F.log_softmax(logits.float(), dim=-1)