Skip to content

Commit 742d2f4

Browse files
committed
numba implementation for dtw, replacing dtw-python
1 parent cfd2b81 commit 742d2f4

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
numba
12
numpy
23
torch
34
tqdm
45
more-itertools
56
transformers>=4.19.0
67
ffmpeg-python==0.2.0
7-
dtw-python==1.3.0

whisper/timing.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List, TYPE_CHECKING
22

3+
import numba
34
import numpy as np
45
import torch
56
import torch.nn.functional as F
@@ -21,6 +22,48 @@ def median_filter(x: torch.Tensor, filter_width: int):
2122
return slices.median(dim=-1).values
2223

2324

25+
# NOTE: parallel=True was dropped — the DP recurrence below is strictly
# sequential (each cell depends on its left/upper neighbors), so numba finds
# nothing to parallelize and only emits a NumbaPerformanceWarning for it.
@numba.jit(nopython=True)
def dtw(x: np.ndarray):
    """Dynamic Time Warping: minimum-cost monotonic alignment path through `x`.

    Parameters
    ----------
    x : np.ndarray
        2-D cost matrix of shape (N, M); lower values indicate a better match.

    Returns
    -------
    np.ndarray
        Integer array of shape (2, K): row 0 holds the first-axis indices and
        row 1 the second-axis indices of the optimal path, ordered from
        (0, 0) up to (N - 1, M - 1).
    """
    N, M = x.shape
    # Pad the DP tables with one extra row/column so the recurrence can read
    # the (i-1, j-1) neighbors without bounds checks; the border stays +inf
    # except for the (0, 0) start cell.
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

    i, j = 0, 0  # pre-bind so the variables are typed for the backtrack below
    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]  # diagonal step: advance both indices
            c1 = cost[i - 1, j]      # advance the first index only
            c2 = cost[i, j - 1]      # advance the second index only

            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                # ties (and the remaining case) fall through to the
                # column move, preserving the original tie-breaking
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    # Backtrack from the bottom-right corner. The loop condition stops at the
    # padded border, so the -1 sentinels in `trace` are never dereferenced.
    result = []
    while i > 0 and j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            # fixed message: the original said "P[i, j]", a name left over
            # from the dtw-python library this code replaced
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    return result[::-1, :].T
2467
def add_word_timestamps(
2568
model: "Whisper",
2669
tokenizer: Tokenizer,
@@ -34,8 +77,6 @@ def add_word_timestamps(
3477
if len(segments) == 0:
3578
return
3679

37-
from dtw import dtw
38-
3980
# install hooks on the cross attention layers to retrieve the attention weights
4081
QKs = [None] * model.dims.n_text_layer
4182
hooks = [
@@ -67,12 +108,11 @@ def add_word_timestamps(
67108
weights = (weights * qk_scale).softmax(dim=-1)
68109

69110
w = weights / weights.norm(dim=-2, keepdim=True)
70-
matrix = w.mean(axis=(0, 1)).neg().double().cpu().numpy()
71-
72-
alignment = dtw(matrix)
111+
matrix = w.mean(axis=(0, 1)).neg().cpu().numpy()
112+
text_indices, time_indices = dtw(matrix)
73113

74-
jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
75-
jump_times = alignment.index2s[jumps] / TOKENS_PER_SECOND
114+
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
115+
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
76116

77117
if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
78118
# These languages don't typically use spaces, so it is difficult to split words

0 commit comments

Comments
 (0)