Skip to content

Commit c27cbc3

Browse files
jongwook and ryanheise authored
word-level timestamps in transcribe() (openai#869)
* word-level timestamps in `transcribe()` * moving to `timing.py` * numba implementation for dtw, replacing dtw-python * triton implementation for dtw * add test for dtw implementations * triton implementation of median_filter * a simple word-level timestamps test * add scipy as dev dependency * installs an older version of Triton if CUDA < 11.4 * fix broken merge * loosen nvcc version match regex * find_alignment() function * miscellaneous improvements * skip median filtering when the input is too small * Expose punctuation options in cli and transcribe() (openai#973) * fix merge error * fix merge error 2 * annotating that word_timestamps is experimental --------- Co-authored-by: ryanheise <ryan@ryanheise.com>
1 parent f6d0264 commit c27cbc3

File tree

14 files changed

+769
-78
lines changed

14 files changed

+769
-78
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,5 @@ jobs:
2121
- run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
2222
- uses: actions/checkout@v2
2323
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
24-
- run: pip install pytest
25-
- run: pip install .
26-
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
24+
- run: pip install .["dev"]
25+
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
numba
12
numpy
23
torch
34
tqdm

setup.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23

34
import pkg_resources
45
from setuptools import setup, find_packages
@@ -9,6 +10,21 @@ def read_version(fname="whisper/version.py"):
910
return locals()["__version__"]
1011

1112

13+
requirements = []
14+
if sys.platform.startswith("linux"):
15+
triton_requirement = "triton>=2.0.0.dev20221202"
16+
try:
17+
import re
18+
import subprocess
19+
version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
20+
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
21+
if (int(major), int(minor)) < (11, 4):
22+
# the last version supporting CUDA < 11.4
23+
triton_requirement = "triton==2.0.0.dev20221011"
24+
except (IndexError, OSError, subprocess.SubprocessError):
25+
pass
26+
requirements.append(triton_requirement)
27+
1228
setup(
1329
name="openai-whisper",
1430
py_modules=["whisper"],
@@ -22,7 +38,7 @@ def read_version(fname="whisper/version.py"):
2238
url="https://github.com/openai/whisper",
2339
license="MIT",
2440
packages=find_packages(exclude=["tests*"]),
25-
install_requires=[
41+
install_requires=requirements + [
2642
str(r)
2743
for r in pkg_resources.parse_requirements(
2844
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
@@ -32,5 +48,5 @@ def read_version(fname="whisper/version.py"):
3248
"console_scripts": ["whisper=whisper.transcribe:cli"],
3349
},
3450
include_package_data=True,
35-
extras_require={"dev": ["pytest"]},
51+
extras_require={"dev": ["pytest", "scipy"]},
3652
)

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import random as rand
2+
3+
import numpy
4+
import pytest
5+
6+
7+
def pytest_configure(config):
8+
config.addinivalue_line("markers", "requires_cuda")
9+
10+
11+
@pytest.fixture
12+
def random():
13+
rand.seed(42)
14+
numpy.random.seed(42)

tests/test_timing.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import pytest
2+
import numpy as np
3+
import scipy.ndimage
4+
import torch
5+
6+
from whisper.timing import dtw_cpu, dtw_cuda, median_filter
7+
8+
9+
sizes = [
10+
(10, 20), (32, 16), (123, 1500), (234, 189),
11+
]
12+
shapes = [
13+
(10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
14+
]
15+
16+
17+
@pytest.mark.parametrize("N, M", sizes)
18+
def test_dtw(N: int, M: int):
19+
steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
20+
np.random.shuffle(steps)
21+
x = np.random.random((N, M)).astype(np.float32)
22+
23+
i, j, k = 0, 0, 0
24+
trace = []
25+
while True:
26+
x[i, j] -= 1
27+
trace.append((i, j))
28+
29+
if k == len(steps):
30+
break
31+
32+
if k + 1 < len(steps) and steps[k] != steps[k + 1]:
33+
i += 1
34+
j += 1
35+
k += 2
36+
continue
37+
38+
if steps[k] == 0:
39+
i += 1
40+
if steps[k] == 1:
41+
j += 1
42+
k += 1
43+
44+
trace = np.array(trace).T
45+
dtw_trace = dtw_cpu(x)
46+
47+
assert np.allclose(trace, dtw_trace)
48+
49+
50+
@pytest.mark.requires_cuda
51+
@pytest.mark.parametrize("N, M", sizes)
52+
def test_dtw_cuda_equivalence(N: int, M: int):
53+
x_numpy = np.random.randn(N, M).astype(np.float32)
54+
x_cuda = torch.from_numpy(x_numpy).cuda()
55+
56+
trace_cpu = dtw_cpu(x_numpy)
57+
trace_cuda = dtw_cuda(x_cuda)
58+
59+
assert np.allclose(trace_cpu, trace_cuda)
60+
61+
62+
@pytest.mark.parametrize("shape", shapes)
63+
def test_median_filter(shape):
64+
x = torch.randn(*shape)
65+
66+
for filter_width in [3, 5, 7, 13]:
67+
filtered = median_filter(x, filter_width)
68+
69+
# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
70+
pad_width = filter_width // 2
71+
padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
72+
scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
73+
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
74+
75+
assert np.allclose(filtered, scipy_filtered)
76+
77+
78+
@pytest.mark.requires_cuda
79+
@pytest.mark.parametrize("shape", shapes)
80+
def test_median_filter_equivalence(shape):
81+
x = torch.randn(*shape)
82+
83+
for filter_width in [3, 5, 7, 13]:
84+
filtered_cpu = median_filter(x, filter_width)
85+
filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
86+
87+
assert np.allclose(filtered_cpu, filtered_gpu)

tests/test_transcribe.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,22 @@ def test_transcribe(model_name: str):
1313
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
1414

1515
language = "en" if model_name.endswith(".en") else None
16-
result = model.transcribe(audio_path, language=language, temperature=0.0)
16+
result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
1717
assert result["language"] == "en"
1818

1919
transcription = result["text"].lower()
2020
assert "my fellow americans" in transcription
2121
assert "your country" in transcription
2222
assert "do for you" in transcription
23+
24+
timing_checked = False
25+
for segment in result["segments"]:
26+
for timing in segment["words"]:
27+
assert timing["start"] < timing["end"]
28+
if timing["word"].strip(" ,") == "Americans":
29+
assert timing["start"] <= 1.8
30+
assert timing["end"] >= 1.8
31+
print(timing)
32+
timing_checked = True
33+
34+
assert timing_checked

whisper/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,23 @@
2929
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
3030
}
3131

32+
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
33+
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
34+
_ALIGNMENT_HEADS = {
35+
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
36+
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
37+
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
38+
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
39+
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
40+
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
41+
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
42+
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
43+
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
44+
"large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
45+
"large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
46+
}
47+
48+
3249

3350
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
3451
os.makedirs(root, exist_ok=True)
@@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
106123

107124
if name in _MODELS:
108125
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
126+
alignment_heads = _ALIGNMENT_HEADS[name]
109127
elif os.path.isfile(name):
110128
checkpoint_file = open(name, "rb").read() if in_memory else name
129+
alignment_heads = None
111130
else:
112131
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
113132

@@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
119138
model = Whisper(dims)
120139
model.load_state_dict(checkpoint["model_state_dict"])
121140

141+
if alignment_heads is not None:
142+
model.set_alignment_heads(alignment_heads)
143+
122144
return model.to(device)

whisper/audio.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
1919
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
2020

21+
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
22+
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 100 mel frames in 1s (10ms each)
23+
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 50 audio tokens in 1s (20ms each)
24+
2125

2226
def load_audio(file: str, sr: int = SAMPLE_RATE):
2327
"""

whisper/model.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import base64
2+
import gzip
13
from dataclasses import dataclass
24
from typing import Dict
35
from typing import Iterable, Optional
@@ -8,8 +10,8 @@
810
from torch import Tensor
911
from torch import nn
1012

11-
from .transcribe import transcribe as transcribe_function
1213
from .decoding import detect_language as detect_language_function, decode as decode_function
14+
from .transcribe import transcribe as transcribe_function
1315

1416

1517
@dataclass
@@ -213,6 +215,15 @@ def __init__(self, dims: ModelDimensions):
213215
self.dims.n_text_head,
214216
self.dims.n_text_layer,
215217
)
218+
# use the last half layers for alignment by default; see `set_alignment_heads()` below
219+
all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
220+
all_heads[self.dims.n_text_layer // 2:] = True
221+
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
222+
223+
def set_alignment_heads(self, dump: bytes):
224+
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
225+
mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
226+
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
216227

217228
def embed_audio(self, mel: torch.Tensor):
218229
return self.encoder(mel)

0 commit comments

Comments
 (0)