Skip to content

Commit f5040f9

Browse files
jongwook authored and ilanit1997 committed
Decoding improvements (openai#1033)
* suppress task tokens (transcribe/translate)
* no longer ignore the last segment when it ends with a single timestamp
1 parent 368acfe commit f5040f9

File tree

3 files changed

+31
-37
lines changed

3 files changed

+31
-37
lines changed

whisper/decoding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,13 @@ def _get_suppress_tokens(self) -> Tuple[int]:
553553
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
554554

555555
suppress_tokens.extend(
556-
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
556+
[
557+
self.tokenizer.transcribe,
558+
self.tokenizer.translate,
559+
self.tokenizer.sot,
560+
self.tokenizer.sot_prev,
561+
self.tokenizer.sot_lm
562+
]
557563
)
558564
if self.tokenizer.no_speech is not None:
559565
# no-speech probability is collected separately

whisper/tokenizer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,14 @@ def decode_with_timestamps(self, tokens) -> str:
160160
def eot(self) -> int:
161161
return self.tokenizer.eos_token_id
162162

163+
@cached_property
164+
def transcribe(self) -> int:
165+
return self._get_single_token_id("<|transcribe|>")
166+
167+
@cached_property
168+
def translate(self) -> int:
169+
return self._get_single_token_id("<|translate|>")
170+
163171
@cached_property
164172
def sot(self) -> int:
165173
return self._get_single_token_id("<|startoftranscript|>")

whisper/transcribe.py

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
146146
initial_prompt_tokens = []
147147

148148
def add_segment(
149-
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult, encoder_embeddings, decoder_embeddings
149+
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
150150
):
151151
text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
152152
if len(text.strip()) == 0: # skip empty text output
@@ -164,8 +164,6 @@ def add_segment(
164164
"avg_logprob": result.avg_logprob,
165165
"compression_ratio": result.compression_ratio,
166166
"no_speech_prob": result.no_speech_prob,
167-
"encoder_embeddings":encoder_embeddings,
168-
"decoder_embeddings":decoder_embeddings
169167
}
170168
)
171169
if verbose:
@@ -199,59 +197,41 @@ def add_segment(
199197
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
200198
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
201199
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
200+
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
201+
consecutive = consecutive.tolist() + [len(tokens)]
202202
last_slice = 0
203203
for current_slice in consecutive:
204204
sliced_tokens = tokens[last_slice:current_slice]
205-
start_timestamp_position = (
206-
sliced_tokens[0].item() - tokenizer.timestamp_begin
207-
)
208-
end_timestamp_position = min(
209-
sliced_tokens[-1].item() - tokenizer.timestamp_begin,
210-
np.ceil((num_frames - seek) / input_stride) - 1
211-
)
212-
encoder_embeddings = result.encoder_embeddings[:, :,
213-
start_timestamp_position:int(end_timestamp_position)]
214-
decoder_embeddings = result.decoder_embeddings[:,:, int(last_slice)+1:int(current_slice)-1]
215-
205+
start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
206+
end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
216207
add_segment(
217-
start=timestamp_offset + start_timestamp_position * time_precision,
218-
end=timestamp_offset + end_timestamp_position * time_precision,
208+
start=timestamp_offset + start_timestamp_pos * time_precision,
209+
end=timestamp_offset + end_timestamp_pos * time_precision,
219210
text_tokens=sliced_tokens[1:-1],
220211
result=result,
221-
encoder_embeddings=encoder_embeddings,
222-
decoder_embeddings=decoder_embeddings
223212
)
224213
last_slice = current_slice
225-
last_timestamp_position = (
226-
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
227-
)
228-
seek += last_timestamp_position * input_stride
214+
if ended_with_single_timestamp:
215+
# single timestamp at the end means no speech after the last timestamp.
216+
seek += segment.shape[-1]
217+
else:
218+
# otherwise, ignore the unfinished segment and seek to the last timestamp
219+
last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
220+
seek += last_timestamp_pos * input_stride
229221
all_tokens.extend(tokens[: last_slice + 1].tolist())
230222
else:
231223
duration = segment_duration
232224
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
233225
if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
234226
# no consecutive timestamps but it has a timestamp; use the last one.
235-
# single timestamp at the end means no speech after the last timestamp.
236-
last_timestamp_position = min(
237-
timestamps[-1].item() - tokenizer.timestamp_begin,
238-
np.ceil((num_frames - seek) / input_stride) - 1
239-
)
240-
duration = last_timestamp_position * time_precision
227+
last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
228+
duration = last_timestamp_pos * time_precision
241229

242-
start_timestamp_position = (
243-
timestamps[0].item() - tokenizer.timestamp_begin
244-
)
245-
encoder_embeddings = result.encoder_embeddings[:, :,
246-
start_timestamp_position:int(last_timestamp_position)]
247-
decoder_embeddings = result.decoder_embeddings[:, :, 1:-1]
248230
add_segment(
249231
start=timestamp_offset,
250232
end=timestamp_offset + duration,
251233
text_tokens=tokens,
252234
result=result,
253-
encoder_embeddings=encoder_embeddings,
254-
decoder_embeddings=decoder_embeddings
255235
)
256236

257237
seek += segment.shape[-1]

0 commit comments

Comments (0)