Skip to content

Commit f6cf30b

Browse files
jongwook and abyesilyurt
authored and committed
fix all_tokens handling that caused more repetitions and discrepancy in JSON (openai#1060)
1 parent 7ebe2a4 commit f6cf30b

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

tests/test_transcribe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_transcribe():
1818
audio_path, language=language, temperature=0.0, word_timestamps=True
1919
)
2020
assert result["language"] == "en"
21+
assert result["text"] == "".join([s["text"] for s in result["segments"]])
2122

2223
transcription = result["text"].lower()
2324
assert "my fellow americans" in transcription

whisper/timing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def add_word_timestamps(
290290
if len(segments) == 0:
291291
return
292292

293-
text_tokens = [t for segment in segments for t in segment["tokens"]]
293+
text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
294294
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
295295
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
296296

whisper/transcribe.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,14 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
201201
def new_segment(
202202
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
203203
):
204-
text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
204+
tokens = tokens.tolist()
205+
text_tokens = [token for token in tokens if token < tokenizer.eot]
205206
return {
206-
"id": len(all_segments),
207207
"seek": seek,
208208
"start": start,
209209
"end": end,
210210
"text": tokenizer.decode(text_tokens),
211-
"tokens": text_tokens,
211+
"tokens": tokens,
212212
"temperature": result.temperature,
213213
"avg_logprob": result.avg_logprob,
214214
"compression_ratio": result.compression_ratio,
@@ -246,7 +246,6 @@ def new_segment(
246246

247247
previous_seek = seek
248248
current_segments = []
249-
current_tokens = []
250249

251250
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
252251
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -276,7 +275,6 @@ def new_segment(
276275
result=result,
277276
)
278277
)
279-
current_tokens.append(sliced_tokens.tolist())
280278
last_slice = current_slice
281279

282280
if single_timestamp_ending:
@@ -288,7 +286,6 @@ def new_segment(
288286
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
289287
)
290288
seek += last_timestamp_pos * input_stride
291-
all_tokens.extend(tokens[: last_slice + 1].tolist())
292289
else:
293290
duration = segment_duration
294291
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@@ -310,7 +307,6 @@ def new_segment(
310307
result=result,
311308
)
312309
)
313-
current_tokens.append(tokens.tolist())
314310
seek += segment_size
315311

316312
if not condition_on_previous_text or result.temperature > 0.5:
@@ -349,11 +345,17 @@ def new_segment(
349345
segment["text"] = ""
350346
segment["tokens"] = []
351347
segment["words"] = []
352-
current_tokens[i] = []
353348

354-
all_segments.extend(current_segments)
349+
all_segments.extend(
350+
[
351+
{"id": i, **segment}
352+
for i, segment in enumerate(
353+
current_segments, start=len(all_segments)
354+
)
355+
]
356+
)
355357
all_tokens.extend(
356-
[token for segment in current_tokens for token in segment]
358+
[token for segment in current_segments for token in segment["tokens"]]
357359
)
358360

359361
# update progress bar

0 commit comments

Comments (0)