Skip to content

Commit 7a3114b

Browse files
ilanit1997 authored and jongwook committed
drop python 3.7 support (openai#889)
1 parent 00cc3cc commit 7a3114b

File tree

3 files changed

+33
-49
lines changed

3 files changed

+33
-49
lines changed

whisper/decoding.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -256,11 +256,10 @@ def __init__(self, temperature: float, eot: int):
256256
self.eot = eot
257257

258258
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
259-
temperature = self.temperature
260-
if temperature == 0:
259+
if self.temperature == 0:
261260
next_tokens = logits.argmax(dim=-1)
262261
else:
263-
next_tokens = Categorical(logits=logits / temperature).sample()
262+
next_tokens = Categorical(logits=logits / self.temperature).sample()
264263

265264
logprobs = F.log_softmax(logits.float(), dim=-1)
266265
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
@@ -515,10 +514,8 @@ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
515514

516515
def _get_initial_tokens(self) -> Tuple[int]:
517516
tokens = list(self.sot_sequence)
518-
prefix = self.options.prefix
519-
prompt = self.options.prompt
520517

521-
if prefix:
518+
if prefix := self.options.prefix:
522519
prefix_tokens = (
523520
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
524521
)
@@ -527,7 +524,7 @@ def _get_initial_tokens(self) -> Tuple[int]:
527524
prefix_tokens = prefix_tokens[-max_prefix_len:]
528525
tokens = tokens + prefix_tokens
529526

530-
if prompt:
527+
if prompt := self.options.prompt:
531528
prompt_tokens = (
532529
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
533530
)
@@ -721,13 +718,9 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOpt
721718
result: Union[DecodingResult, List[DecodingResult]]
722719
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
723720
"""
724-
single = mel.ndim == 2
725-
if single:
721+
if single := mel.ndim == 2:
726722
mel = mel.unsqueeze(0)
727723

728724
result = DecodingTask(model, options).run(mel)
729-
730-
if single:
731-
result = result[0]
732725

733-
return result
726+
return result[0] if single else result

whisper/tokenizer.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from dataclasses import dataclass
3-
from functools import lru_cache
3+
from functools import lru_cache, cached_property
44
from typing import List, Optional, Tuple, Union
55

66
import numpy as np
@@ -156,43 +156,35 @@ def decode_with_timestamps(self, tokens) -> str:
156156
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
157157
return "".join(outputs)
158158

159-
@property
160-
@lru_cache()
159+
@cached_property
161160
def eot(self) -> int:
162161
return self.tokenizer.eos_token_id
163162

164-
@property
165-
@lru_cache()
163+
@cached_property
166164
def sot(self) -> int:
167165
return self._get_single_token_id("<|startoftranscript|>")
168166

169-
@property
170-
@lru_cache()
167+
@cached_property
171168
def sot_lm(self) -> int:
172169
return self._get_single_token_id("<|startoflm|>")
173170

174-
@property
175-
@lru_cache()
171+
@cached_property
176172
def sot_prev(self) -> int:
177173
return self._get_single_token_id("<|startofprev|>")
178174

179-
@property
180-
@lru_cache()
175+
@cached_property
181176
def no_speech(self) -> int:
182177
return self._get_single_token_id("<|nospeech|>")
183178

184-
@property
185-
@lru_cache()
179+
@cached_property
186180
def no_timestamps(self) -> int:
187181
return self._get_single_token_id("<|notimestamps|>")
188182

189-
@property
190-
@lru_cache()
183+
@cached_property
191184
def timestamp_begin(self) -> int:
192185
return self.tokenizer.all_special_ids[-1] + 1
193186

194-
@property
195-
@lru_cache()
187+
@cached_property
196188
def language_token(self) -> int:
197189
"""Returns the token id corresponding to the value of the `language` field"""
198190
if self.language is None:
@@ -210,8 +202,7 @@ def language_token(self) -> int:
210202

211203
raise KeyError(f"Language {self.language} not found in tokenizer.")
212204

213-
@property
214-
@lru_cache()
205+
@cached_property
215206
def all_language_tokens(self) -> Tuple[int]:
216207
result = []
217208
for token, token_id in zip(
@@ -222,18 +213,15 @@ def all_language_tokens(self) -> Tuple[int]:
222213
result.append(token_id)
223214
return tuple(result)
224215

225-
@property
226-
@lru_cache()
216+
@cached_property
227217
def all_language_codes(self) -> Tuple[str]:
228218
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
229219

230-
@property
231-
@lru_cache()
220+
@cached_property
232221
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
233222
return tuple(list(self.sot_sequence) + [self.no_timestamps])
234223

235-
@property
236-
@lru_cache()
224+
@cached_property
237225
def non_speech_tokens(self) -> Tuple[int]:
238226
"""
239227
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech

whisper/transcribe.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def transcribe(
2626
logprob_threshold: Optional[float] = -1.0,
2727
no_speech_threshold: Optional[float] = 0.6,
2828
condition_on_previous_text: bool = True,
29+
initial_prompt: Optional[str] = None,
2930
**decode_options,
3031
):
3132
"""
@@ -138,10 +139,11 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
138139
all_segments = []
139140
prompt_reset_since = 0
140141

141-
initial_prompt = decode_options.pop("initial_prompt", None) or []
142-
if initial_prompt:
143-
initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
144-
all_tokens.extend(initial_prompt)
142+
if initial_prompt is not None:
143+
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
144+
all_tokens.extend(initial_prompt_tokens)
145+
else:
146+
initial_prompt_tokens = []
145147

146148
def add_segment(
147149
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult, encoder_embeddings, decoder_embeddings
@@ -263,7 +265,11 @@ def add_segment(
263265
pbar.update(min(num_frames, seek) - previous_seek_value)
264266
previous_seek_value = seek
265267

266-
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
268+
return dict(
269+
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
270+
segments=all_segments,
271+
language=language
272+
)
267273

268274

269275
def cli():
@@ -312,21 +318,18 @@ def cli():
312318
args["language"] = "en"
313319

314320
temperature = args.pop("temperature")
315-
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
316-
if temperature_increment_on_fallback is not None:
317-
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
321+
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
322+
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
318323
else:
319324
temperature = [temperature]
320325

321-
threads = args.pop("threads")
322-
if threads > 0:
326+
if (threads := args.pop("threads")) > 0:
323327
torch.set_num_threads(threads)
324328

325329
from . import load_model
326330
model = load_model(model_name, device=device, download_root=model_dir)
327331

328332
writer = get_writer(output_format, output_dir)
329-
330333
for audio_path in args.pop("audio"):
331334
result = transcribe(model, audio_path, temperature=temperature, **args)
332335
writer(result, audio_path)

0 commit comments

Comments
 (0)