@@ -146,7 +146,7 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
146146 initial_prompt_tokens = []
147147
148148 def add_segment (
149- * , start : float , end : float , text_tokens : torch .Tensor , result : DecodingResult , encoder_embeddings , decoder_embeddings
149+ * , start : float , end : float , text_tokens : torch .Tensor , result : DecodingResult
150150 ):
151151 text = tokenizer .decode ([token for token in text_tokens if token < tokenizer .eot ])
152152 if len (text .strip ()) == 0 : # skip empty text output
@@ -164,8 +164,6 @@ def add_segment(
164164 "avg_logprob" : result .avg_logprob ,
165165 "compression_ratio" : result .compression_ratio ,
166166 "no_speech_prob" : result .no_speech_prob ,
167- "encoder_embeddings" :encoder_embeddings ,
168- "decoder_embeddings" :decoder_embeddings
169167 }
170168 )
171169 if verbose :
@@ -199,59 +197,41 @@ def add_segment(
199197 timestamp_tokens : torch .Tensor = tokens .ge (tokenizer .timestamp_begin )
200198 consecutive = torch .where (timestamp_tokens [:- 1 ] & timestamp_tokens [1 :])[0 ].add_ (1 )
201199 if len (consecutive ) > 0 : # if the output contains two consecutive timestamp tokens
200+ if ended_with_single_timestamp := timestamp_tokens [- 2 :].tolist () == [False , True ]:
201+ consecutive = consecutive .tolist () + [len (tokens )]
202202 last_slice = 0
203203 for current_slice in consecutive :
204204 sliced_tokens = tokens [last_slice :current_slice ]
205- start_timestamp_position = (
206- sliced_tokens [0 ].item () - tokenizer .timestamp_begin
207- )
208- end_timestamp_position = min (
209- sliced_tokens [- 1 ].item () - tokenizer .timestamp_begin ,
210- np .ceil ((num_frames - seek ) / input_stride ) - 1
211- )
212- encoder_embeddings = result .encoder_embeddings [:, :,
213- start_timestamp_position :int (end_timestamp_position )]
214- decoder_embeddings = result .decoder_embeddings [:,:, int (last_slice )+ 1 :int (current_slice )- 1 ]
215-
205+ start_timestamp_pos = sliced_tokens [0 ].item () - tokenizer .timestamp_begin
206+ end_timestamp_pos = sliced_tokens [- 1 ].item () - tokenizer .timestamp_begin
216207 add_segment (
217- start = timestamp_offset + start_timestamp_position * time_precision ,
218- end = timestamp_offset + end_timestamp_position * time_precision ,
208+ start = timestamp_offset + start_timestamp_pos * time_precision ,
209+ end = timestamp_offset + end_timestamp_pos * time_precision ,
219210 text_tokens = sliced_tokens [1 :- 1 ],
220211 result = result ,
221- encoder_embeddings = encoder_embeddings ,
222- decoder_embeddings = decoder_embeddings
223212 )
224213 last_slice = current_slice
225- last_timestamp_position = (
226- tokens [last_slice - 1 ].item () - tokenizer .timestamp_begin
227- )
228- seek += last_timestamp_position * input_stride
214+ if ended_with_single_timestamp :
215+ # single timestamp at the end means no speech after the last timestamp.
216+ seek += segment .shape [- 1 ]
217+ else :
218+ # otherwise, ignore the unfinished segment and seek to the last timestamp
219+ last_timestamp_pos = tokens [last_slice - 1 ].item () - tokenizer .timestamp_begin
220+ seek += last_timestamp_pos * input_stride
229221 all_tokens .extend (tokens [: last_slice + 1 ].tolist ())
230222 else :
231223 duration = segment_duration
232224 timestamps = tokens [timestamp_tokens .nonzero ().flatten ()]
233225 if len (timestamps ) > 0 and timestamps [- 1 ].item () != tokenizer .timestamp_begin :
234226 # no consecutive timestamps but it has a timestamp; use the last one.
235- # single timestamp at the end means no speech after the last timestamp.
236- last_timestamp_position = min (
237- timestamps [- 1 ].item () - tokenizer .timestamp_begin ,
238- np .ceil ((num_frames - seek ) / input_stride ) - 1
239- )
240- duration = last_timestamp_position * time_precision
227+ last_timestamp_pos = timestamps [- 1 ].item () - tokenizer .timestamp_begin
228+ duration = last_timestamp_pos * time_precision
241229
242- start_timestamp_position = (
243- timestamps [0 ].item () - tokenizer .timestamp_begin
244- )
245- encoder_embeddings = result .encoder_embeddings [:, :,
246- start_timestamp_position :int (last_timestamp_position )]
247- decoder_embeddings = result .decoder_embeddings [:, :, 1 :- 1 ]
248230 add_segment (
249231 start = timestamp_offset ,
250232 end = timestamp_offset + duration ,
251233 text_tokens = tokens ,
252234 result = result ,
253- encoder_embeddings = encoder_embeddings ,
254- decoder_embeddings = decoder_embeddings
255235 )
256236
257237 seek += segment .shape [- 1 ]
0 commit comments