Skip to content

Commit eab8d92

Browse files
authored
Decoding improvements (#1033)
* suppress task tokens (transcribe/translate) * do not ignore the last segment when it ends with a single timestamp
1 parent 3e1780f commit eab8d92

File tree

3 files changed

+30
-16
lines changed

3 files changed

+30
-16
lines changed

whisper/decoding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,13 @@ def _get_suppress_tokens(self) -> Tuple[int]:
549549
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
550550

551551
suppress_tokens.extend(
552-
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
552+
[
553+
self.tokenizer.transcribe,
554+
self.tokenizer.translate,
555+
self.tokenizer.sot,
556+
self.tokenizer.sot_prev,
557+
self.tokenizer.sot_lm
558+
]
553559
)
554560
if self.tokenizer.no_speech is not None:
555561
# no-speech probability is collected separately

whisper/tokenizer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,14 @@ def decode_with_timestamps(self, tokens) -> str:
160160
def eot(self) -> int:
161161
return self.tokenizer.eos_token_id
162162

163+
@cached_property
164+
def transcribe(self) -> int:
165+
return self._get_single_token_id("<|transcribe|>")
166+
167+
@cached_property
168+
def translate(self) -> int:
169+
return self._get_single_token_id("<|translate|>")
170+
163171
@cached_property
164172
def sot(self) -> int:
165173
return self._get_single_token_id("<|startoftranscript|>")

whisper/transcribe.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -197,35 +197,35 @@ def add_segment(
197197
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
198198
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
199199
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
200+
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
201+
consecutive = consecutive.tolist() + [len(tokens)]
200202
last_slice = 0
201203
for current_slice in consecutive:
202204
sliced_tokens = tokens[last_slice:current_slice]
203-
start_timestamp_position = (
204-
sliced_tokens[0].item() - tokenizer.timestamp_begin
205-
)
206-
end_timestamp_position = (
207-
sliced_tokens[-1].item() - tokenizer.timestamp_begin
208-
)
205+
start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
206+
end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
209207
add_segment(
210-
start=timestamp_offset + start_timestamp_position * time_precision,
211-
end=timestamp_offset + end_timestamp_position * time_precision,
208+
start=timestamp_offset + start_timestamp_pos * time_precision,
209+
end=timestamp_offset + end_timestamp_pos * time_precision,
212210
text_tokens=sliced_tokens[1:-1],
213211
result=result,
214212
)
215213
last_slice = current_slice
216-
last_timestamp_position = (
217-
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
218-
)
219-
seek += last_timestamp_position * input_stride
214+
if ended_with_single_timestamp:
215+
# single timestamp at the end means no speech after the last timestamp.
216+
seek += segment.shape[-1]
217+
else:
218+
# otherwise, ignore the unfinished segment and seek to the last timestamp
219+
last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
220+
seek += last_timestamp_pos * input_stride
220221
all_tokens.extend(tokens[: last_slice + 1].tolist())
221222
else:
222223
duration = segment_duration
223224
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
224225
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
225226
# no consecutive timestamps but it has a timestamp; use the last one.
226-
# single timestamp at the end means no speech after the last timestamp.
227-
last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
228-
duration = last_timestamp_position * time_precision
227+
last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
228+
duration = last_timestamp_pos * time_precision
229229

230230
add_segment(
231231
start=timestamp_offset,

0 commit comments

Comments (0)