11from typing import List , TYPE_CHECKING
22
3+ import numba
34import numpy as np
45import torch
56import torch .nn .functional as F
@@ -21,6 +22,48 @@ def median_filter(x: torch.Tensor, filter_width: int):
2122 return slices .median (dim = - 1 ).values
2223
2324
@numba.jit(nopython=True)
def dtw(x: np.ndarray):
    """Dynamic time warping: minimum-cost monotonic path through ``x``.

    Used to align text tokens to audio frames via the (negated)
    cross-attention weights.

    Parameters
    ----------
    x : np.ndarray
        An (N, M) matrix of per-cell alignment costs.

    Returns
    -------
    np.ndarray
        A (2, K) array; row 0 holds indices into the first axis (text),
        row 1 indices into the second axis (time), ordered from
        (0, 0) to (N - 1, M - 1).

    Raises
    ------
    ValueError
        If the backtrace reads a trace cell that was never written
        (cannot happen for a well-formed cost matrix).
    """
    N, M = x.shape
    # +inf border on row/column 0 forces every path back to cost[0, 0].
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]  # diagonal: advance text and time
            c1 = cost[i - 1, j]      # advance text only
            c2 = cost[i, j - 1]      # advance time only

            # Take the true minimum. The previous strict-< chain fell
            # through to c2 whenever c0 == c1 <= c2, accumulating a
            # larger-than-minimal cost on ties; with <= we prefer the
            # diagonal, then c1, and always keep the DTW recurrence.
            if c0 <= c1 and c0 <= c2:
                c, t = c0, 0
            elif c1 <= c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    # Backtrace from the bottom-right corner. The inf border guarantees
    # the step taken at (1, 1) is diagonal, ending the walk at (0, 0),
    # so the -1 sentinel cells are never consulted.
    i, j = N, M
    result = []
    while i > 0 and j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    return result[::-1, :].T
66+
2467def add_word_timestamps (
2568 model : "Whisper" ,
2669 tokenizer : Tokenizer ,
@@ -34,8 +77,6 @@ def add_word_timestamps(
3477 if len (segments ) == 0 :
3578 return
3679
37- from dtw import dtw
38-
3980 # install hooks on the cross attention layers to retrieve the attention weights
4081 QKs = [None ] * model .dims .n_text_layer
4182 hooks = [
@@ -67,12 +108,11 @@ def add_word_timestamps(
67108 weights = (weights * qk_scale ).softmax (dim = - 1 )
68109
69110 w = weights / weights .norm (dim = - 2 , keepdim = True )
70- matrix = w .mean (axis = (0 , 1 )).neg ().double ().cpu ().numpy ()
71-
72- alignment = dtw (matrix )
111+ matrix = w .mean (axis = (0 , 1 )).neg ().cpu ().numpy ()
112+ text_indices , time_indices = dtw (matrix )
73113
74- jumps = np .pad (np .diff (alignment . index1s ), (1 , 0 ), constant_values = 1 ).astype (bool )
75- jump_times = alignment . index2s [jumps ] / TOKENS_PER_SECOND
114+ jumps = np .pad (np .diff (text_indices ), (1 , 0 ), constant_values = 1 ).astype (bool )
115+ jump_times = time_indices [jumps ] / TOKENS_PER_SECOND
76116
77117 if tokenizer .language in {"zh" , "ja" , "th" , "lo" , "my" }:
78118 # These languages don't typically use spaces, so it is difficult to split words
0 commit comments