|
| 1 | +import json |
| 2 | +import os |
1 | 3 | import zlib |
2 | | -from typing import Iterator, TextIO |
| 4 | +from typing import Callable, TextIO |
3 | 5 |
|
4 | 6 |
|
5 | 7 | def exact_div(x, y): |
@@ -45,44 +47,83 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal |
45 | 47 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" |
46 | 48 |
|
47 | 49 |
|
def write_txt(transcript: Iterator[dict], file: TextIO):
    """Write the text of every segment to *file*, one segment per line.

    Each segment's ``'text'`` value is stripped of surrounding whitespace
    before being written, and the stream is flushed after every line.
    """
    for entry in transcript:
        line = entry['text'].strip()
        print(line, file=file, flush=True)
51 | | - |
52 | | - |
def write_vtt(transcript: Iterator[dict], file: TextIO):
    """Write *transcript* to *file* as a WebVTT document.

    A ``WEBVTT`` header is emitted first, followed by one cue per segment.
    Any ``-->`` inside the segment text is rewritten to ``->`` so it cannot
    be confused with the cue timing separator.
    """
    print("WEBVTT\n", file=file)
    for entry in transcript:
        begin = format_timestamp(entry['start'])
        finish = format_timestamp(entry['end'])
        caption = entry['text'].strip().replace('-->', '->')
        # The trailing "\n" plus print's own newline leaves a blank line
        # between cues.
        print(f"{begin} --> {finish}\n{caption}\n", file=file, flush=True)
62 | | - |
63 | | - |
def write_srt(transcript: Iterator[dict], file: TextIO):
    """Write *transcript* to *file* in SRT (SubRip) format.

    Cues are numbered from 1. Timestamps always include the hours field and
    use a comma as the decimal marker. Any ``-->`` inside the segment text
    is rewritten to ``->``.

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt

        result = transcribe(model, audio_path, temperature=temperature, **args)

        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    index = 0
    for entry in transcript:
        index += 1
        begin = format_timestamp(entry['start'], always_include_hours=True, decimal_marker=',')
        finish = format_timestamp(entry['end'], always_include_hours=True, decimal_marker=',')
        caption = entry['text'].strip().replace('-->', '->')
        # Cue layout: counter, timing line, text, then a blank separator line.
        print(f"{index}\n{begin} --> {finish}\n{caption}\n", file=file, flush=True)
class ResultWriter:
    """Base class for callables that save a transcription result to disk.

    Subclasses set ``extension`` and implement ``write_result``. Calling an
    instance opens ``<output_dir>/<basename of audio_path>.<extension>`` for
    writing as UTF-8 and hands the open file to ``write_result``.
    """

    # File extension (without the leading dot) appended to the output name.
    extension: str

    def __init__(self, output_dir: str):
        """Remember the directory that output files are written into."""
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str):
        """Write *result* to a file named after *audio_path*.

        Only the directory part of *audio_path* is dropped; the audio file's
        own suffix stays in the output name.
        """
        filename = f"{os.path.basename(audio_path)}.{self.extension}"
        destination = os.path.join(self.output_dir, filename)

        with open(destination, "w", encoding="utf-8") as handle:
            self.write_result(result, file=handle)

    def write_result(self, result: dict, file: TextIO):
        """Serialize *result* into *file*; subclasses must override this."""
        raise NotImplementedError
| 65 | + |
| 66 | + |
class WriteTXT(ResultWriter):
    """Writer that saves the transcript as plain text, one segment per line."""

    extension: str = "txt"

    def write_result(self, result: dict, file: TextIO):
        """Print the stripped text of each entry in ``result["segments"]``."""
        for entry in result["segments"]:
            line = entry['text'].strip()
            print(line, file=file, flush=True)
| 73 | + |
| 74 | + |
class WriteVTT(ResultWriter):
    """Writer that saves the transcript as a WebVTT document."""

    extension: str = "vtt"

    def write_result(self, result: dict, file: TextIO):
        """Emit the ``WEBVTT`` header, then one cue per segment.

        Any ``-->`` inside the segment text is rewritten to ``->`` so it
        cannot be confused with the cue timing separator.
        """
        print("WEBVTT\n", file=file)
        for entry in result["segments"]:
            begin = format_timestamp(entry['start'])
            finish = format_timestamp(entry['end'])
            caption = entry['text'].strip().replace('-->', '->')
            # The trailing "\n" plus print's own newline leaves a blank line
            # between cues.
            print(f"{begin} --> {finish}\n{caption}\n", file=file, flush=True)
| 87 | + |
| 88 | + |
class WriteSRT(ResultWriter):
    """Writer that saves the transcript in SRT (SubRip) format."""

    extension: str = "srt"

    def write_result(self, result: dict, file: TextIO):
        """Emit one numbered cue per segment, counting from 1.

        Timestamps always include the hours field and use a comma as the
        decimal marker. Any ``-->`` inside the segment text is rewritten
        to ``->``.
        """
        index = 0
        for entry in result["segments"]:
            index += 1
            begin = format_timestamp(entry['start'], always_include_hours=True, decimal_marker=',')
            finish = format_timestamp(entry['end'], always_include_hours=True, decimal_marker=',')
            caption = entry['text'].strip().replace('-->', '->')
            # Cue layout: counter, timing line, text, then a blank separator.
            print(f"{index}\n{begin} --> {finish}\n{caption}\n", file=file, flush=True)
| 103 | + |
| 104 | + |
class WriteJSON(ResultWriter):
    """Writer that saves the whole result dictionary as JSON."""

    extension: str = "json"

    def write_result(self, result: dict, file: TextIO):
        """Serialize *result* into *file* using ``json.dump`` defaults."""
        json.dump(obj=result, fp=file)
| 110 | + |
| 111 | + |
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, str], None]:
    """Return a callable that writes a transcription result into *output_dir*.

    Args:
        output_format: One of ``"txt"``, ``"vtt"``, ``"srt"``, ``"json"``, or
            ``"all"`` to write every supported format.
        output_dir: Directory passed to each writer's constructor.

    Returns:
        A callable invoked as ``writer(result, audio_path)``. The second
        argument is the path of the source audio file (a string), which the
        writers use to name their output file; it is not an open file object.

    Raises:
        KeyError: If *output_format* is neither ``"all"`` nor a supported
            format name.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "json": WriteJSON,
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        # Same calling convention as a single ResultWriter, so callers can
        # treat the "all" case and the single-format case uniformly.
        def write_all(result: dict, audio_path: str) -> None:
            for writer in all_writers:
                writer(result, audio_path)

        return write_all

    return writers[output_format](output_dir)
0 commit comments