Skip to content

Commit 627ea97

Browse files
jongwook authored and zackees committed
apply formatting with black (openai#1038)
* applying black (with the default 88-column limit) * add flake8 * add isort * fix isort
1 parent ff3ab06 commit 627ea97

File tree

21 files changed

+533
-227
lines changed

21 files changed

+533
-227
lines changed

.flake8

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[flake8]
2+
per-file-ignores =
3+
*/__init__.py: F401
4+

.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,7 @@ jobs:
2222
- uses: actions/checkout@v2
2323
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
2424
- run: pip install .["dev"]
25+
- run: black --check --diff -t py38 --include '(\.pyi?)$' .
26+
- run: isort --check --diff .
27+
- run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
2528
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[tool.black]
2+
3+
[tool.isort]
4+
profile = "black"
5+
include_trailing_comma = true
6+
line_length = 88
7+
multi_line_output = 3
8+

setup.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33

44
import pkg_resources
5-
from setuptools import setup, find_packages
5+
from setuptools import find_packages, setup
66

77
# The directory containing this file
88
HERE = os.path.dirname(__file__)
@@ -17,7 +17,10 @@
1717
try:
1818
import re
1919
import subprocess
20-
version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
20+
21+
version_line = (
22+
subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
23+
)
2124
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
2225
if (int(major), int(minor)) < (11, 4):
2326
# the last version supporting CUDA < 11.4
@@ -37,7 +40,8 @@
3740
url="https://github.com/openai/whisper",
3841
license="MIT",
3942
packages=find_packages(exclude=["tests*"]),
40-
install_requires=requirements + [
43+
install_requires=requirements
44+
+ [
4145
str(r)
4246
for r in pkg_resources.parse_requirements(
4347
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
@@ -47,5 +51,5 @@
4751
"console_scripts": ["whisper=whisper.transcribe:cli"],
4852
},
4953
include_package_data=True,
50-
extras_require={"dev": ["pytest", "scipy"]},
54+
extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
5155
)

tests/test_audio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
5+
from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
66

77

88
def test_audio():

tests/test_normalizer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import pytest
22

33
from whisper.normalizers import EnglishTextNormalizer
4-
from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
4+
from whisper.normalizers.english import (
5+
EnglishNumberNormalizer,
6+
EnglishSpellingNormalizer,
7+
)
58

69

710
@pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])

tests/test_timing.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
1-
import pytest
21
import numpy as np
2+
import pytest
33
import scipy.ndimage
44
import torch
55

66
from whisper.timing import dtw_cpu, dtw_cuda, median_filter
77

8-
98
sizes = [
10-
(10, 20), (32, 16), (123, 1500), (234, 189),
9+
(10, 20),
10+
(32, 16),
11+
(123, 1500),
12+
(234, 189),
1113
]
1214
shapes = [
13-
(10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
15+
(10,),
16+
(1, 15),
17+
(4, 5, 345),
18+
(6, 12, 240, 512),
1419
]
1520

1621

@@ -68,8 +73,12 @@ def test_median_filter(shape):
6873

6974
# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
7075
pad_width = filter_width // 2
71-
padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
72-
scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
76+
padded_x = np.pad(
77+
x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
78+
)
79+
scipy_filtered = scipy.ndimage.median_filter(
80+
padded_x, [1] * (x.ndim - 1) + [filter_width]
81+
)
7382
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
7483

7584
assert np.allclose(filtered, scipy_filtered)

tests/test_transcribe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ def test_transcribe(model_name: str):
1313
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
1414

1515
language = "en" if model_name.endswith(".en") else None
16-
result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
16+
result = model.transcribe(
17+
audio_path, language=language, temperature=0.0, word_timestamps=True
18+
)
1719
assert result["language"] == "en"
1820

1921
transcription = result["text"].lower()

whisper/__init__.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010

1111
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
1212
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
13-
from .model import Whisper, ModelDimensions
13+
from .model import ModelDimensions, Whisper
1414
from .transcribe import transcribe
1515
from .version import __version__
1616

17-
1817
_MODELS = {
1918
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
2019
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -41,12 +40,11 @@
4140
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
4241
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
4342
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
44-
"large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
45-
"large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
43+
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
44+
"large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
4645
}
4746

4847

49-
5048
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
5149
os.makedirs(root, exist_ok=True)
5250

@@ -62,10 +60,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
6260
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
6361
return model_bytes if in_memory else download_target
6462
else:
65-
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
63+
warnings.warn(
64+
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
65+
)
6666

6767
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
68-
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
68+
with tqdm(
69+
total=int(source.info().get("Content-Length")),
70+
ncols=80,
71+
unit="iB",
72+
unit_scale=True,
73+
unit_divisor=1024,
74+
) as loop:
6975
while True:
7076
buffer = source.read(8192)
7177
if not buffer:
@@ -76,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
7682

7783
model_bytes = open(download_target, "rb").read()
7884
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
79-
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
85+
raise RuntimeError(
86+
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
87+
)
8088

8189
return model_bytes if in_memory else download_target
8290

@@ -86,7 +94,12 @@ def available_models() -> List[str]:
8694
return list(_MODELS.keys())
8795

8896

89-
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
97+
def load_model(
98+
name: str,
99+
device: Optional[Union[str, torch.device]] = None,
100+
download_root: str = None,
101+
in_memory: bool = False,
102+
) -> Whisper:
90103
"""
91104
Load a Whisper ASR model
92105
@@ -111,15 +124,8 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
111124
if device is None:
112125
device = "cuda" if torch.cuda.is_available() else "cpu"
113126
if download_root is None:
114-
download_root = os.path.join(
115-
os.getenv(
116-
"XDG_CACHE_HOME",
117-
os.path.join(
118-
os.path.expanduser("~"), ".cache"
119-
)
120-
),
121-
"whisper"
122-
)
127+
default = os.path.join(os.path.expanduser("~"), ".cache")
128+
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
123129

124130
if name in _MODELS:
125131
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
@@ -128,9 +134,13 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
128134
checkpoint_file = open(name, "rb").read() if in_memory else name
129135
alignment_heads = None
130136
else:
131-
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
137+
raise RuntimeError(
138+
f"Model {name} not found; available models = {available_models()}"
139+
)
132140

133-
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
141+
with (
142+
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
143+
) as fp:
134144
checkpoint = torch.load(fp, map_location=device)
135145
del checkpoint_file
136146

whisper/__main__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from .transcribe import cli
22

3-
43
cli()

0 commit comments

Comments (0)