
Commit e18f40f

andrewchernyh authored and jongwook committed

Fix infinite loop caused by incorrect timestamp tokens prediction (openai#914)

* Fix infinite loop caused by incorrect timestamp tokens prediction (openai#810)
* Update decoding.py

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
1 parent 7de7967 commit e18f40f

File tree

1 file changed (+7, -1 lines changed)

whisper/decoding.py

Lines changed: 7 additions & 1 deletion
@@ -412,7 +412,8 @@ def apply(self, logits: Tensor, tokens: Tensor):

         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         for k in range(tokens.shape[0]):
-            seq = [t for t in tokens[k, self.sample_begin :].tolist()]
+            sampled_tokens = tokens[k, self.sample_begin :]
+            seq = [t for t in sampled_tokens.tolist()]
             last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
             penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin

@@ -422,6 +423,11 @@ def apply(self, logits: Tensor, tokens: Tensor):
                 else:  # cannot be normal text tokens
                     logits[k, : self.tokenizer.eot] = -np.inf

+            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
+            if timestamps.numel() > 0:
+                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
+                logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
+
         if tokens.shape[1] == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
             logits[:, : self.tokenizer.timestamp_begin] = -np.inf
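The new masking step can be illustrated in isolation. The sketch below is a toy example with hypothetical values (a 20-token vocabulary and `timestamp_begin = 10`; the real tokenizer's values are much larger, and the real code operates on PyTorch tensors rather than NumPy arrays). It shows the rule the patch enforces: once a timestamp token has been sampled, every timestamp token representing an earlier time is masked to `-inf`, so the decoder cannot jump backwards in time and loop over the same segment.

```python
import numpy as np

# Toy setup (hypothetical values): tokens >= timestamp_begin are
# timestamp tokens, ordered by the time they represent.
timestamp_begin = 10
logits = np.zeros(20)

# Suppose the sampled sequence so far contains timestamp tokens 12 and 14
# mixed with ordinary text tokens.
sampled_tokens = np.array([12, 3, 4, 14])
timestamps = sampled_tokens[sampled_tokens >= timestamp_begin]

if timestamps.size > 0:
    # Same rule as the patch: timestamps shouldn't decrease, so forbid
    # every timestamp token smaller than the last one sampled.
    logits[timestamp_begin : timestamps[-1]] = -np.inf

# Tokens 10..13 are now forbidden; token 14 onward and all text tokens remain.
masked = np.flatnonzero(np.isinf(logits)).tolist()
print(masked)  # [10, 11, 12, 13]
```

Masking only the half-open range up to `timestamps[-1]` still allows the last timestamp to be sampled again, which is what lets a timestamp pair close at the same position.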
