|
@@ -469,9 +469,7 @@ class ApplyTimestampRules(LogitFilter):
|
|
|
]
|
|
|
if timestamps.numel() > 0:
|
|
|
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
|
|
|
- logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
|
|
|
-
|
|
|
- # to force that timestamps are strictly increasing
|
|
|
+ # also force each segment to have a nonzero length, to prevent infinite looping
|
|
|
if last_was_timestamp and not penultimate_was_timestamp:
|
|
|
timestamp_last = timestamps[-1]
|
|
|
else:
|