Skip to content

Commit 5fa4356

Browse files
committed
miscellaneous improvements
1 parent ff6cbfd commit 5fa4356

File tree

9 files changed

+278
-116
lines changed

9 files changed

+278
-116
lines changed

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import pytest
55

66

7+
def pytest_configure(config):
8+
config.addinivalue_line("markers", "requires_cuda")
9+
10+
711
@pytest.fixture
812
def random():
913
rand.seed(42)

tests/test_timing.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
(10, 20), (32, 16), (123, 1500), (234, 189),
1111
]
1212
shapes = [
13-
(4, 5, 20, 345), (6, 12, 240, 512),
13+
(10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
1414
]
1515

1616

@@ -65,7 +65,12 @@ def test_median_filter(shape):
6565

6666
for filter_width in [3, 5, 7, 13]:
6767
filtered = median_filter(x, filter_width)
68-
scipy_filtered = scipy.ndimage.median_filter(x, (1, 1, 1, filter_width), mode="nearest")
68+
69+
# using np.pad to reflect-pad, because SciPy's behavior is different near the edges.
70+
pad_width = filter_width // 2
71+
padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
72+
scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
73+
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
6974

7075
assert np.allclose(filtered, scipy_filtered)
7176

tests/test_transcribe.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ def test_transcribe(model_name: str):
2525
for segment in result["segments"]:
2626
for timing in segment["words"]:
2727
assert timing["start"] < timing["end"]
28-
if timing["word"].strip() == "Americans":
29-
assert timing["start"] <= 1.75
30-
assert timing["end"] >= 2.05
28+
if timing["word"].strip(" ,") == "Americans":
29+
assert timing["start"] <= 1.8
30+
assert timing["end"] >= 1.8
31+
print(timing)
3132
timing_checked = True
3233

3334
assert timing_checked

whisper/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,23 @@
2929
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
3030
}
3131

32+
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
33+
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
34+
_ALIGNMENT_HEADS = {
35+
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
36+
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
37+
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
38+
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
39+
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
40+
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
41+
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
42+
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
43+
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
44+
"large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
45+
"large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
46+
}
47+
48+
3249

3350
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
3451
os.makedirs(root, exist_ok=True)
@@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
106123

107124
if name in _MODELS:
108125
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
126+
alignment_heads = _ALIGNMENT_HEADS[name]
109127
elif os.path.isfile(name):
110128
checkpoint_file = open(name, "rb").read() if in_memory else name
129+
alignment_heads = None
111130
else:
112131
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
113132

@@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
119138
model = Whisper(dims)
120139
model.load_state_dict(checkpoint["model_state_dict"])
121140

141+
if alignment_heads is not None:
142+
model.set_alignment_heads(alignment_heads)
143+
122144
return model.to(device)

whisper/model.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from dataclasses import dataclass
22
from typing import Dict
33
from typing import Iterable, Optional
4-
4+
import gzip
5+
import base64
56
import numpy as np
67
import torch
78
import torch.nn.functional as F
@@ -213,6 +214,15 @@ def __init__(self, dims: ModelDimensions):
213214
self.dims.n_text_head,
214215
self.dims.n_text_layer,
215216
)
217+
# use the last half of the layers for alignment by default; see `set_alignment_heads()` below
218+
all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
219+
all_heads[self.dims.n_text_layer // 2:] = True
220+
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
221+
222+
def set_alignment_heads(self, dump: bytes):
223+
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
224+
mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
225+
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
216226

217227
def embed_audio(self, mel: torch.Tensor):
218228
return self.encoder(mel)

whisper/timing.py

Lines changed: 128 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
from typing import List, TYPE_CHECKING
1+
import time
2+
import subprocess
3+
import warnings
24
from dataclasses import dataclass
5+
from typing import List, TYPE_CHECKING
6+
37
import numba
48
import numpy as np
59
import torch
@@ -14,17 +18,31 @@
1418

1519
def median_filter(x: torch.Tensor, filter_width: int):
1620
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
17-
assert 3 <= x.ndim <= 4, "`median_filter()` is implemented for only 3D or 4D tensors"
21+
if (ndim := x.ndim) <= 2: # `F.pad` does not support 1D or 2D inputs for reflect padding
22+
x = x[None, None, :]
23+
1824
assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
1925

20-
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode='replicate')
26+
result = None
27+
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
2128
if x.is_cuda:
22-
from .triton_ops import median_filter_cuda
23-
return median_filter_cuda(x, filter_width)
29+
try:
30+
from .triton_ops import median_filter_cuda
31+
result = median_filter_cuda(x, filter_width)
32+
except (RuntimeError, subprocess.CalledProcessError):
33+
warnings.warn(
34+
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
35+
"falling back to a slower median kernel implementation..."
36+
)
37+
38+
if result is None:
39+
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
40+
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
2441

25-
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
26-
return x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
42+
if ndim <= 2:
43+
result = result[0, 0]
2744

45+
return result
2846

2947
@numba.jit
3048
def backtrace(trace: np.ndarray):
@@ -108,17 +126,24 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
108126

109127
def dtw(x: torch.Tensor) -> np.ndarray:
110128
if x.is_cuda:
111-
return dtw_cuda(x)
129+
try:
130+
return dtw_cuda(x)
131+
except (RuntimeError, subprocess.CalledProcessError):
132+
warnings.warn(
133+
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
134+
"falling back to a slower DTW implementation..."
135+
)
112136

113137
return dtw_cpu(x.double().cpu().numpy())
114138

115139

116140
@dataclass
117-
class Alignment:
118-
words: List[str]
119-
word_tokens: List[List[int]]
120-
start_times: np.ndarray
121-
end_times: np.ndarray
141+
class WordTiming:
142+
word: str
143+
tokens: List[int]
144+
start: float
145+
end: float
146+
probability: float
122147

123148

124149
def find_alignment(
@@ -128,16 +153,14 @@ def find_alignment(
128153
mel: torch.Tensor,
129154
num_frames: int,
130155
*,
131-
max_qk_layers: int = 6,
132156
medfilt_width: int = 7,
133157
qk_scale: float = 1.0,
134-
) -> Alignment:
158+
) -> List[WordTiming]:
135159
tokens = torch.tensor(
136160
[
137161
*tokenizer.sot_sequence,
138-
tokenizer.timestamp_begin,
162+
tokenizer.no_timestamps,
139163
*text_tokens,
140-
tokenizer.timestamp_begin + num_frames // 2,
141164
tokenizer.eot,
142165
]
143166
).to(model.device)
@@ -146,78 +169,132 @@ def find_alignment(
146169
QKs = [None] * model.dims.n_text_layer
147170
hooks = [
148171
block.cross_attn.register_forward_hook(
149-
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
172+
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
150173
)
151174
for i, block in enumerate(model.decoder.blocks)
152175
]
153176

154177
with torch.no_grad():
155-
model(mel.unsqueeze(0), tokens.unsqueeze(0))
178+
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
179+
token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
180+
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
156181

157182
for hook in hooks:
158183
hook.remove()
159184

160-
weights = torch.cat(QKs[-max_qk_layers:]) # layers * heads * tokens * frames
161-
weights = weights[:, :, :, : num_frames // 2]
162-
weights = median_filter(weights, medfilt_width)
185+
# heads * tokens * frames
186+
weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
187+
weights = weights[:, :, : num_frames // 2]
163188
weights = (weights * qk_scale).softmax(dim=-1)
164-
weights = weights / weights.norm(dim=-2, keepdim=True)
165-
matrix = weights.mean(axis=(0, 1)).neg()
189+
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
190+
weights = (weights - mean) / std
191+
weights = median_filter(weights, medfilt_width)
192+
193+
matrix = weights.mean(axis=0)
194+
matrix = matrix[len(tokenizer.sot_sequence):-1]
195+
text_indices, time_indices = dtw(-matrix)
166196

167-
text_indices, time_indices = dtw(matrix)
197+
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
198+
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
168199

169200
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
170201
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
171-
172-
if tokenizer.language in {"zh", "ja", "th", "lo", "my"}:
173-
# These languages don't typically use spaces, so it is difficult to split words
174-
# without morpheme analysis. Here, we instead split words at any
175-
# position where the tokens are decoded as valid unicode points
176-
words, word_tokens = tokenizer.split_tokens_on_unicode(tokens[1:].tolist())
177-
else:
178-
words, word_tokens = tokenizer.split_tokens_on_spaces(tokens[1:].tolist())
179-
180-
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
181202
start_times = jump_times[word_boundaries[:-1]]
182203
end_times = jump_times[word_boundaries[1:]]
204+
word_probabilities = [
205+
np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
206+
]
207+
208+
# hack: ensure the first and second words are not longer than twice the median word duration.
209+
# a better segmentation algorithm based on VAD should be able to replace this.
210+
word_durations = end_times - start_times
211+
word_durations = word_durations[word_durations.nonzero()]
212+
if len(word_durations) > 0:
213+
median_duration = np.median(word_durations)
214+
max_duration = median_duration * 2
215+
if len(word_durations) >= 2 and word_durations[1] > max_duration:
216+
end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
217+
if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
218+
start_times[0] = max(0, end_times[0] - max_duration)
219+
220+
return [
221+
WordTiming(word, tokens, start, end, probability)
222+
for word, tokens, start, end, probability in zip(
223+
words, word_tokens, start_times, end_times, word_probabilities
224+
)
225+
]
226+
183227

184-
return Alignment(words, word_tokens, start_times, end_times)
228+
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
229+
# merge prepended punctuations
230+
i = len(alignment) - 2
231+
j = len(alignment) - 1
232+
while i >= 0:
233+
previous = alignment[i]
234+
following = alignment[j]
235+
if previous.word.startswith(" ") and previous.word.strip() in prepended:
236+
# prepend it to the following word
237+
following.word = previous.word + following.word
238+
following.tokens = previous.tokens + following.tokens
239+
previous.word = ""
240+
previous.tokens = []
241+
else:
242+
j = i
243+
i -= 1
244+
245+
# merge appended punctuations
246+
i = 0
247+
j = 1
248+
while j < len(alignment):
249+
previous = alignment[i]
250+
following = alignment[j]
251+
if not previous.word.endswith(" ") and following.word in appended:
252+
# append it to the previous word
253+
previous.word = previous.word + following.word
254+
previous.tokens = previous.tokens + following.tokens
255+
following.word = ""
256+
following.tokens = []
257+
else:
258+
i = j
259+
j += 1
185260

186261

187262
def add_word_timestamps(
263+
*,
188264
segments: List[dict],
189265
model: "Whisper",
190266
tokenizer: Tokenizer,
191267
mel: torch.Tensor,
192268
num_frames: int,
269+
prepend_punctuations: str = "\"\'“¿([{-",
270+
append_punctuations: str = "\"\'.。,,!!??::”)]}、",
193271
**hyperparams,
194272
):
195273
if len(segments) == 0:
196274
return
197275

198276
text_tokens = [t for segment in segments for t in segment["tokens"]]
199277
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
278+
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
200279

201280
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
202-
alignment.start_times = time_offset + alignment.start_times
203-
alignment.end_times = time_offset + alignment.end_times
204-
205281
token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
206-
token_sources: List[int] = [None] * len(tokenizer.sot_sequence) + list(token_sources)
207282

208283
for segment in segments:
209284
segment["words"] = []
210285

211-
word_boundaries = np.pad(np.cumsum([len(t) for t in alignment.word_tokens]), (1, 0))
212-
for i, (word, start, end) in enumerate(zip(alignment.words, alignment.start_times, alignment.end_times)):
213-
if word.startswith("<|") or word.strip() in ".,!?、。": # TODO: expand
214-
continue
215-
216-
segment = segments[token_sources[word_boundaries[i]]]
217-
segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2)))
286+
word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
287+
for i, timing in enumerate(alignment):
288+
if timing.word:
289+
segment = segments[token_sources[word_boundaries[i]]]
290+
start = round(time_offset + timing.start, 2)
291+
end = round(time_offset + timing.end, 2)
292+
segment["words"].append(
293+
dict(word=timing.word, start=start, end=end, probability=timing.probability)
294+
)
218295

219-
# adjust the segment-level timestamps based on the word-level timestamps
220296
for segment in segments:
221-
if len(segment["words"]) > 0:
222-
segment["start"] = segment["words"][0]["start"]
223-
segment["end"] = segment["words"][-1]["end"]
297+
if len(words := segment["words"]) > 0:
298+
# adjust the segment-level timestamps based on the word-level timestamps
299+
segment["start"] = words[0]["start"]
300+
segment["end"] = words[-1]["end"]

0 commit comments

Comments
 (0)