|
@@ -471,6 +471,13 @@ class ApplyTimestampRules(LogitFilter):
|
|
|
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
|
|
|
logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
|
|
|
|
|
|
+ # to force that timestamps are strictly increasing
|
|
|
+ if last_was_timestamp and not penultimate_was_timestamp:
|
|
|
+ timestamp_last = timestamps[-1]
|
|
|
+ else:
|
|
|
+ timestamp_last = timestamps[-1] + 1
|
|
|
+ logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
|
|
|
+
|
|
|
if tokens.shape[1] == self.sample_begin:
|
|
|
# suppress generating non-timestamp tokens at the beginning
|
|
|
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|