@@ -201,14 +201,14 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
     def new_segment(
         *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
     ):
-        text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
+        tokens = tokens.tolist()
+        text_tokens = [token for token in tokens if token < tokenizer.eot]
         return {
-            "id": len(all_segments),
             "seek": seek,
             "start": start,
             "end": end,
             "text": tokenizer.decode(text_tokens),
-            "tokens": text_tokens,
+            "tokens": tokens,
             "temperature": result.temperature,
             "avg_logprob": result.avg_logprob,
             "compression_ratio": result.compression_ratio,
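The `new_segment` hunk changes what a segment records: `"tokens"` now keeps the full decoded token list, timestamp tokens included, while `"text"` is still decoded only from tokens below `tokenizer.eot` (timestamp token ids sit above `eot` in Whisper's vocabulary). A toy sketch of that filter, with illustrative token ids rather than real tokenizer output:

```python
# Illustrative ids only: ordinary text tokens sit below eot,
# timestamp tokens sit above it.
eot = 50257

tokens = [50364, 15496, 995, 50464]           # <|0.00|> "Hello world" <|2.00|>
text_tokens = [t for t in tokens if t < eot]  # strips timestamps (and eot)

assert text_tokens == [15496, 995]
# After this change the segment dict stores `tokens` (all four ids), so
# downstream consumers can recover the timestamps; only the decoded
# "text" field uses the filtered list.
```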
@@ -246,7 +246,6 @@ def new_segment(
 
             previous_seek = seek
             current_segments = []
-            current_tokens = []
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -276,7 +275,6 @@ def new_segment(
                             result=result,
                         )
                     )
-                    current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
 
                 if single_timestamp_ending:
@@ -288,7 +286,6 @@ def new_segment(
                         tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                     )
                     seek += last_timestamp_pos * input_stride
-                    all_tokens.extend(tokens[: last_slice + 1].tolist())
             else:
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@@ -310,7 +307,6 @@ def new_segment(
                         result=result,
                     )
                 )
-                current_tokens.append(tokens.tolist())
                 seek += segment_size
 
             if not condition_on_previous_text or result.temperature > 0.5:
@@ -349,11 +345,17 @@ def new_segment(
                     segment["text"] = ""
                     segment["tokens"] = []
                     segment["words"] = []
-                    current_tokens[i] = []
 
-            all_segments.extend(current_segments)
+            all_segments.extend(
+                [
+                    {"id": i, **segment}
+                    for i, segment in enumerate(
+                        current_segments, start=len(all_segments)
+                    )
+                ]
+            )
             all_tokens.extend(
-                [token for segment in current_tokens for token in segment]
+                [token for segment in current_segments for token in segment["tokens"]]
            )
 
             # update progress bar
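Net effect of the remaining hunks: the parallel `current_tokens` bookkeeping disappears, segment ids are assigned once per window when `current_segments` is committed (continuing from `len(all_segments)`, so ids stay contiguous), and `all_tokens` is rebuilt from each segment's `"tokens"` field. A minimal standalone sketch of that commit step, using toy segment dicts in place of real decoder output:

```python
# Toy data standing in for earlier windows' results.
all_segments = [{"id": 0, "tokens": [50364, 400, 50464]}]
all_tokens = [50364, 400, 50464]

# Segments from the current window: no "id" yet, full token lists.
current_segments = [
    {"tokens": [50464, 731, 50564]},
    {"tokens": [50564, 904, 50664]},
]

# Ids are assigned only at commit time, continuing from len(all_segments).
all_segments.extend(
    {"id": i, **segment}
    for i, segment in enumerate(current_segments, start=len(all_segments))
)

# The running token list is derived from the segments themselves.
all_tokens.extend(
    token for segment in current_segments for token in segment["tokens"]
)

assert [s["id"] for s in all_segments] == [0, 1, 2]
```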