Skip to content

Commit f999e5a

Browse files
authored
Update decoding.py
Changes from openai/whisper#914
1 parent 17982ad commit f999e5a

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

whisperx/decoding.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -413,7 +413,8 @@ def apply(self, logits: Tensor, tokens: Tensor):
413 413

414 414
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
415 415
for k in range(tokens.shape[0]):
416-
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
416+
sampled_tokens = tokens[k, self.sample_begin :]
417+
seq = [t for t in sampled_tokens.tolist()]
417 418
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
418 419
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
419 420

@@ -422,6 +423,11 @@ def apply(self, logits: Tensor, tokens: Tensor):
422 423
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
423 424
else: # cannot be normal text tokens
424 425
logits[k, : self.tokenizer.eot] = -np.inf
426+
427+
timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
428+
if timestamps.numel() > 0:
429+
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
430+
logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
425 431

426 432
if tokens.shape[1] == self.sample_begin:
427 433
# suppress generating non-timestamp tokens at the beginning

0 commit comments

Comments (0)