|
@@ -423,10 +423,14 @@ class ApplyTimestampRules(LogitFilter):
|
|
|
else: # cannot be normal text tokens
|
|
|
logits[k, : self.tokenizer.eot] = -np.inf
|
|
|
|
|
|
- # apply the `max_initial_timestamp` option
|
|
|
- if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
|
|
|
- last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
|
|
- logits[:, last_allowed + 1 :] = -np.inf
|
|
|
+ if tokens.shape[1] == self.sample_begin:
|
|
|
+ # suppress generating non-timestamp tokens at the beginning
|
|
|
+ logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
|
|
+
|
|
|
+ # apply the `max_initial_timestamp` option
|
|
|
+ if self.max_initial_timestamp_index is not None:
|
|
|
+ last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
|
|
+ logits[:, last_allowed + 1 :] = -np.inf
|
|
|
|
|
|
# if sum of probability over timestamps is above any other token, sample timestamp
|
|
|
logprobs = F.log_softmax(logits.float(), dim=-1)
|