@@ -197,35 +197,35 @@ def add_segment(
197197 timestamp_tokens : torch .Tensor = tokens .ge (tokenizer .timestamp_begin )
198198 consecutive = torch .where (timestamp_tokens [:- 1 ] & timestamp_tokens [1 :])[0 ].add_ (1 )
199199 if len (consecutive ) > 0 : # if the output contains two consecutive timestamp tokens
200+ if ended_with_single_timestamp := timestamp_tokens [- 2 :].tolist () == [False , True ]:
201+ consecutive = consecutive .tolist () + [len (tokens )]
200202 last_slice = 0
201203 for current_slice in consecutive :
202204 sliced_tokens = tokens [last_slice :current_slice ]
203- start_timestamp_position = (
204- sliced_tokens [0 ].item () - tokenizer .timestamp_begin
205- )
206- end_timestamp_position = (
207- sliced_tokens [- 1 ].item () - tokenizer .timestamp_begin
208- )
205+ start_timestamp_pos = sliced_tokens [0 ].item () - tokenizer .timestamp_begin
206+ end_timestamp_pos = sliced_tokens [- 1 ].item () - tokenizer .timestamp_begin
209207 add_segment (
210- start = timestamp_offset + start_timestamp_position * time_precision ,
211- end = timestamp_offset + end_timestamp_position * time_precision ,
208+ start = timestamp_offset + start_timestamp_pos * time_precision ,
209+ end = timestamp_offset + end_timestamp_pos * time_precision ,
212210 text_tokens = sliced_tokens [1 :- 1 ],
213211 result = result ,
214212 )
215213 last_slice = current_slice
216- last_timestamp_position = (
217- tokens [last_slice - 1 ].item () - tokenizer .timestamp_begin
218- )
219- seek += last_timestamp_position * input_stride
214+ if ended_with_single_timestamp :
215+ # single timestamp at the end means no speech after the last timestamp.
216+ seek += segment .shape [- 1 ]
217+ else :
218+ # otherwise, ignore the unfinished segment and seek to the last timestamp
219+ last_timestamp_pos = tokens [last_slice - 1 ].item () - tokenizer .timestamp_begin
220+ seek += last_timestamp_pos * input_stride
220221 all_tokens .extend (tokens [: last_slice + 1 ].tolist ())
221222 else :
222223 duration = segment_duration
223224 timestamps = tokens [timestamp_tokens .nonzero ().flatten ()]
224225 if len (timestamps ) > 0 and timestamps [- 1 ].item () != tokenizer .timestamp_begin :
225226 # no consecutive timestamps but it has a timestamp; use the last one.
226- # single timestamp at the end means no speech after the last timestamp.
227- last_timestamp_position = timestamps [- 1 ].item () - tokenizer .timestamp_begin
228- duration = last_timestamp_position * time_precision
227+ last_timestamp_pos = timestamps [- 1 ].item () - tokenizer .timestamp_begin
228+ duration = last_timestamp_pos * time_precision
229229
230230 add_segment (
231231 start = timestamp_offset ,
0 commit comments