1010
1111from .audio import load_audio , log_mel_spectrogram , pad_or_trim
1212from .decoding import DecodingOptions , DecodingResult , decode , detect_language
13- from .model import Whisper , ModelDimensions
13+ from .model import ModelDimensions , Whisper
1414from .transcribe import transcribe
1515from .version import __version__
1616
17-
1817_MODELS = {
1918 "tiny.en" : "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt" ,
2019 "tiny" : "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt" ,
4140 "medium.en" : b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00" ,
4241 "medium" : b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9" ,
4342 "large-v1" : b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj" ,
44- "large-v2" : b' ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj' ,
45- "large" : b' ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj' ,
43+ "large-v2" : b" ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj" ,
44+ "large" : b" ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj" ,
4645}
4746
4847
49-
5048def _download (url : str , root : str , in_memory : bool ) -> Union [bytes , str ]:
5149 os .makedirs (root , exist_ok = True )
5250
@@ -62,10 +60,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
6260 if hashlib .sha256 (model_bytes ).hexdigest () == expected_sha256 :
6361 return model_bytes if in_memory else download_target
6462 else :
65- warnings .warn (f"{ download_target } exists, but the SHA256 checksum does not match; re-downloading the file" )
63+ warnings .warn (
64+ f"{ download_target } exists, but the SHA256 checksum does not match; re-downloading the file"
65+ )
6666
6767 with urllib .request .urlopen (url ) as source , open (download_target , "wb" ) as output :
68- with tqdm (total = int (source .info ().get ("Content-Length" )), ncols = 80 , unit = 'iB' , unit_scale = True , unit_divisor = 1024 ) as loop :
68+ with tqdm (
69+ total = int (source .info ().get ("Content-Length" )),
70+ ncols = 80 ,
71+ unit = "iB" ,
72+ unit_scale = True ,
73+ unit_divisor = 1024 ,
74+ ) as loop :
6975 while True :
7076 buffer = source .read (8192 )
7177 if not buffer :
@@ -76,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
7682
7783 model_bytes = open (download_target , "rb" ).read ()
7884 if hashlib .sha256 (model_bytes ).hexdigest () != expected_sha256 :
79- raise RuntimeError ("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." )
85+ raise RuntimeError (
86+ "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
87+ )
8088
8189 return model_bytes if in_memory else download_target
8290
@@ -86,7 +94,12 @@ def available_models() -> List[str]:
8694 return list (_MODELS .keys ())
8795
8896
89- def load_model (name : str , device : Optional [Union [str , torch .device ]] = None , download_root : str = None , in_memory : bool = False ) -> Whisper :
97+ def load_model (
98+ name : str ,
99+ device : Optional [Union [str , torch .device ]] = None ,
100+ download_root : str = None ,
101+ in_memory : bool = False ,
102+ ) -> Whisper :
90103 """
91104 Load a Whisper ASR model
92105
@@ -111,15 +124,8 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
111124 if device is None :
112125 device = "cuda" if torch .cuda .is_available () else "cpu"
113126 if download_root is None :
114- download_root = os .path .join (
115- os .getenv (
116- "XDG_CACHE_HOME" ,
117- os .path .join (
118- os .path .expanduser ("~" ), ".cache"
119- )
120- ),
121- "whisper"
122- )
127+ default = os .path .join (os .path .expanduser ("~" ), ".cache" )
128+ download_root = os .path .join (os .getenv ("XDG_CACHE_HOME" , default ), "whisper" )
123129
124130 if name in _MODELS :
125131 checkpoint_file = _download (_MODELS [name ], download_root , in_memory )
@@ -128,9 +134,13 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
128134 checkpoint_file = open (name , "rb" ).read () if in_memory else name
129135 alignment_heads = None
130136 else :
131- raise RuntimeError (f"Model { name } not found; available models = { available_models ()} " )
137+ raise RuntimeError (
138+ f"Model { name } not found; available models = { available_models ()} "
139+ )
132140
133- with (io .BytesIO (checkpoint_file ) if in_memory else open (checkpoint_file , "rb" )) as fp :
141+ with (
142+ io .BytesIO (checkpoint_file ) if in_memory else open (checkpoint_file , "rb" )
143+ ) as fp :
134144 checkpoint = torch .load (fp , map_location = device )
135145 del checkpoint_file
136146
0 commit comments