1- from typing import List , TYPE_CHECKING
1+ import time
2+ import subprocess
3+ import warnings
24from dataclasses import dataclass
5+ from typing import List , TYPE_CHECKING
6+
37import numba
48import numpy as np
59import torch
1418
def median_filter(x: torch.Tensor, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
    pad_width = filter_width // 2
    if x.shape[-1] <= pad_width:
        # F.pad requires the padding width to be smaller than the input dimension;
        # a sequence this short cannot be reflect-padded, so return it unfiltered
        return x

    if (ndim := x.ndim) <= 2:
        # `F.pad` does not support 1D or 2D inputs for reflect padding,
        # so temporarily add batch and channel dimensions
        x = x[None, None, :]

    assert (
        filter_width > 0 and filter_width % 2 == 1
    ), "`filter_width` should be an odd number"

    result = None
    x = F.pad(x, (pad_width, pad_width, 0, 0), mode="reflect")
    if x.is_cuda:
        try:
            from .triton_ops import median_filter_cuda

            result = median_filter_cuda(x, filter_width)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower median kernel implementation..."
            )

    if result is None:
        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
        result = x.unfold(-1, filter_width, 1).sort()[0][..., pad_width]

    if ndim <= 2:
        # remove the batch and channel dimensions that were added above
        result = result[0, 0]

    return result
2846
2947@numba .jit
3048def backtrace (trace : np .ndarray ):
@@ -108,17 +126,24 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
108126
def dtw(x: torch.Tensor) -> np.ndarray:
    """Run dynamic time warping on the cost matrix `x`.

    Prefers the Triton/CUDA kernel for GPU tensors; if the kernel cannot be
    launched, warns and falls back to the CPU implementation.
    """
    on_gpu = x.is_cuda
    if on_gpu:
        try:
            return dtw_cuda(x)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower DTW implementation..."
            )

    return dtw_cpu(x.double().cpu().numpy())
114138
115139
@dataclass
class WordTiming:
    """A single word in the transcript, with its aligned time span and confidence.

    Produced by `find_alignment` and mutated in place by `merge_punctuations`.
    """

    word: str  # the word text; may carry a leading space (checked by merge_punctuations)
    tokens: List[int]  # the token ids that make up this word
    start: float  # start time in seconds (window-relative; offset later in add_word_timestamps)
    end: float  # end time in seconds (window-relative; offset later in add_word_timestamps)
    probability: float  # mean probability of the word's tokens
122147
123148
def find_alignment(
    model: "Whisper",
    tokenizer: Tokenizer,
    text_tokens: List[int],
    mel: torch.Tensor,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:
    """Align `text_tokens` to the audio in `mel` using cross-attention weights.

    Runs the decoder once over the full token sequence, collects the
    cross-attention weights of the model's alignment heads, and finds the
    token-to-time mapping with dynamic time warping. Returns one `WordTiming`
    per word, with window-relative start/end times in seconds and the mean
    token probability of each word.
    """
    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)

    # install hooks on the cross attention layers to retrieve the attention weights;
    # NOTE(review): assumes cross_attn returns the QK weights as its last output — confirm
    # against the model definition; [0] drops the batch dimension
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    with torch.no_grad():
        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
        # probabilities of the actual text tokens (special tokens beyond eot excluded)
        token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()

    for hook in hooks:
        hook.remove()

    # heads * tokens * frames
    weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
    # crop to the audio's actual length; presumably frames are downsampled 2x — confirm
    weights = weights[:, :, : num_frames // 2]
    weights = (weights * qk_scale).softmax(dim=-1)
    # standardize each head's attention along the token axis before filtering
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = (weights - mean) / std
    weights = median_filter(weights, medfilt_width)

    # average over heads, then drop the sot-sequence rows and the final eot row
    matrix = weights.mean(axis=0)
    matrix = matrix[len(tokenizer.sot_sequence):-1]
    # DTW minimizes cost, so negate the (standardized) attention weights
    text_indices, time_indices = dtw(-matrix)

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    # cumulative token counts: word i spans tokens word_boundaries[i]:word_boundaries[i+1]
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

    # positions where the DTW path advances to a new token; convert to seconds
    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [
        np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]

    # hack: ensure the first and second word is not longer than twice the median word duration.
    # a better segmentation algorithm based on VAD should be able to replace this.
    word_durations = end_times - start_times
    word_durations = word_durations[word_durations.nonzero()]
    if len(word_durations) > 0:
        median_duration = np.median(word_durations)
        max_duration = median_duration * 2
        if len(word_durations) >= 2 and word_durations[1] > max_duration:
            end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
        if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
            start_times[0] = max(0, end_times[0] - max_duration)

    return [
        WordTiming(word, tokens, start, end, probability)
        for word, tokens, start, end, probability in zip(
            words, word_tokens, start_times, end_times, word_probabilities
        )
    ]
226+
183227
184- return Alignment (words , word_tokens , start_times , end_times )
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
    """Fold punctuation-only words into their neighbors, in place.

    A word whose stripped text is one of the `prepended` marks is merged into
    the word that follows it; a word equal to one of the `appended` marks is
    merged into the word that precedes it. Emptied entries keep their slot
    with `word == ""` and `tokens == []`.
    """
    # merge prepended punctuations, scanning right to left
    j = len(alignment) - 1
    for i in range(len(alignment) - 2, -1, -1):
        previous, following = alignment[i], alignment[j]
        if previous.word.startswith(" ") and previous.word.strip() in prepended:
            # prepend it to the following word
            following.word = previous.word + following.word
            following.tokens = previous.tokens + following.tokens
            previous.word = ""
            previous.tokens = []
        else:
            j = i

    # merge appended punctuations, scanning left to right
    i = 0
    for j in range(1, len(alignment)):
        previous, following = alignment[i], alignment[j]
        if not previous.word.endswith(" ") and following.word in appended:
            # append it to the previous word
            previous.word = previous.word + following.word
            previous.tokens = previous.tokens + following.tokens
            following.word = ""
            following.tokens = []
        else:
            i = j
185260
186261
def add_word_timestamps(
    *,
    segments: List[dict],
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    **hyperparams,
):
    """Attach a `"words"` list of word-level timing dicts to each segment, in place.

    Aligns the concatenated tokens of all `segments` against the audio via
    `find_alignment`, merges punctuation-only words, then distributes the timed
    words back onto the segment that contributed their tokens. Extra keyword
    arguments are forwarded to `find_alignment`.
    """
    if len(segments) == 0:
        return

    text_tokens = [t for segment in segments for t in segment["tokens"]]
    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    # word timings from find_alignment are relative to the current window;
    # shift them by the time of the first segment's seek position
    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    # token_sources[k] = index of the segment that token k came from
    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])

    for segment in segments:
        segment["words"] = []

    # cumulative token counts: word i spans tokens word_boundaries[i]:word_boundaries[i+1]
    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
    for i, timing in enumerate(alignment):
        if timing.word:  # skip entries emptied by merge_punctuations
            segment = segments[token_sources[word_boundaries[i]]]
            start = round(time_offset + timing.start, 2)
            end = round(time_offset + timing.end, 2)
            segment["words"].append(
                dict(word=timing.word, start=start, end=end, probability=timing.probability)
            )

    for segment in segments:
        if len(words := segment["words"]) > 0:
            # adjust the segment-level timestamps based on the word-level timestamps
            segment["start"] = words[0]["start"]
            segment["end"] = words[-1]["end"]
0 commit comments