Skip to content

Commit 0fa2e31

Browse files
Aaryan369jongwook
authored andcommitted
Added --output_format option (openai#333)
* Added --output option --output option will help select the output files that will be generated. Corrected the logic, which wrongly shows progress bar when verbose is set to False * Changed output_files variable * Changed back the tqdm verbose * refactor output format handling Co-authored-by: Jong Wook Kim <jongwook@openai.com> Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
1 parent 797d103 commit 0fa2e31

File tree

2 files changed

+89
-57
lines changed

2 files changed

+89
-57
lines changed

whisper/transcribe.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
1212
from .decoding import DecodingOptions, DecodingResult
1313
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
14-
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
14+
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, get_writer
1515

1616
if TYPE_CHECKING:
1717
from .model import Whisper
@@ -260,6 +260,7 @@ def cli():
260260
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
261261
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
262262
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
263+
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
263264
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
264265

265266
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
@@ -286,6 +287,7 @@ def cli():
286287
model_name: str = args.pop("model")
287288
model_dir: str = args.pop("model_dir")
288289
output_dir: str = args.pop("output_dir")
290+
output_format: str = args.pop("output_format")
289291
device: str = args.pop("device")
290292
os.makedirs(output_dir, exist_ok=True)
291293

@@ -308,22 +310,11 @@ def cli():
308310
from . import load_model
309311
model = load_model(model_name, device=device, download_root=model_dir)
310312

313+
writer = get_writer(output_format, output_dir)
314+
311315
for audio_path in args.pop("audio"):
312316
result = transcribe(model, audio_path, temperature=temperature, **args)
313-
314-
audio_basename = os.path.basename(audio_path)
315-
316-
# save TXT
317-
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
318-
write_txt(result["segments"], file=txt)
319-
320-
# save VTT
321-
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
322-
write_vtt(result["segments"], file=vtt)
323-
324-
# save SRT
325-
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
326-
write_srt(result["segments"], file=srt)
317+
writer(result, audio_path)
327318

328319

329320
if __name__ == '__main__':

whisper/utils.py

Lines changed: 83 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import json
2+
import os
13
import zlib
2-
from typing import Iterator, TextIO
4+
from typing import Callable, TextIO
35

46

57
def exact_div(x, y):
@@ -45,44 +47,83 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
4547
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
4648

4749

48-
def write_txt(transcript: Iterator[dict], file: TextIO):
49-
for segment in transcript:
50-
print(segment['text'].strip(), file=file, flush=True)
51-
52-
53-
def write_vtt(transcript: Iterator[dict], file: TextIO):
54-
print("WEBVTT\n", file=file)
55-
for segment in transcript:
56-
print(
57-
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
58-
f"{segment['text'].strip().replace('-->', '->')}\n",
59-
file=file,
60-
flush=True,
61-
)
62-
63-
64-
def write_srt(transcript: Iterator[dict], file: TextIO):
65-
"""
66-
Write a transcript to a file in SRT format.
67-
68-
Example usage:
69-
from pathlib import Path
70-
from whisper.utils import write_srt
71-
72-
result = transcribe(model, audio_path, temperature=temperature, **args)
73-
74-
# save SRT
75-
audio_basename = Path(audio_path).stem
76-
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
77-
write_srt(result["segments"], file=srt)
78-
"""
79-
for i, segment in enumerate(transcript, start=1):
80-
# write srt lines
81-
print(
82-
f"{i}\n"
83-
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
84-
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
85-
f"{segment['text'].strip().replace('-->', '->')}\n",
86-
file=file,
87-
flush=True,
88-
)
50+
class ResultWriter:
51+
extension: str
52+
53+
def __init__(self, output_dir: str):
54+
self.output_dir = output_dir
55+
56+
def __call__(self, result: dict, audio_path: str):
57+
audio_basename = os.path.basename(audio_path)
58+
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
59+
60+
with open(output_path, "w", encoding="utf-8") as f:
61+
self.write_result(result, file=f)
62+
63+
def write_result(self, result: dict, file: TextIO):
64+
raise NotImplementedError
65+
66+
67+
class WriteTXT(ResultWriter):
68+
extension: str = "txt"
69+
70+
def write_result(self, result: dict, file: TextIO):
71+
for segment in result["segments"]:
72+
print(segment['text'].strip(), file=file, flush=True)
73+
74+
75+
class WriteVTT(ResultWriter):
76+
extension: str = "vtt"
77+
78+
def write_result(self, result: dict, file: TextIO):
79+
print("WEBVTT\n", file=file)
80+
for segment in result["segments"]:
81+
print(
82+
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
83+
f"{segment['text'].strip().replace('-->', '->')}\n",
84+
file=file,
85+
flush=True,
86+
)
87+
88+
89+
class WriteSRT(ResultWriter):
90+
extension: str = "srt"
91+
92+
def write_result(self, result: dict, file: TextIO):
93+
for i, segment in enumerate(result["segments"], start=1):
94+
# write srt lines
95+
print(
96+
f"{i}\n"
97+
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
98+
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
99+
f"{segment['text'].strip().replace('-->', '->')}\n",
100+
file=file,
101+
flush=True,
102+
)
103+
104+
105+
class WriteJSON(ResultWriter):
106+
extension: str = "json"
107+
108+
def write_result(self, result: dict, file: TextIO):
109+
json.dump(result, file)
110+
111+
112+
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
113+
writers = {
114+
"txt": WriteTXT,
115+
"vtt": WriteVTT,
116+
"srt": WriteSRT,
117+
"json": WriteJSON,
118+
}
119+
120+
if output_format == "all":
121+
all_writers = [writer(output_dir) for writer in writers.values()]
122+
123+
def write_all(result: dict, file: TextIO):
124+
for writer in all_writers:
125+
writer(result, file)
126+
127+
return write_all
128+
129+
return writers[output_format](output_dir)

0 commit comments

Comments
 (0)