Skip to content
2 changes: 1 addition & 1 deletion whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def cli():
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
Expand Down
21 changes: 21 additions & 0 deletions whisper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,25 @@ def write_result(self, result: dict, file: TextIO):
)


class WriteTSV(ResultWriter):
    r"""
    Write a transcript to a file in TSV (tab-separated values) format: a header row
    ``start\tend\ttext`` followed by one row per segment, like:
    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>

    Using integer milliseconds as start and end times means there's no chance of interference from
    an environment setting a language encoding that causes the decimal in a floating point number
    to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.

    Tabs, line feeds and carriage returns inside the transcript text are each replaced by a
    single space, so every segment occupies exactly one three-column row.
    """

    extension: str = "tsv"

    # Characters that would add a column (tab) or split a row (line break) if they
    # were written verbatim inside the text field; each one maps to a single space.
    _FIELD_BREAKERS = str.maketrans({"\t": " ", "\n": " ", "\r": " "})

    def write_result(self, result: dict, file: TextIO):
        """
        Write the header row and one row per entry of ``result["segments"]`` to ``file``.

        Each segment is read through its ``"start"``, ``"end"`` and ``"text"`` keys.
        ``start`` and ``end`` are multiplied by 1000 and rounded to the nearest integer
        (presumably converting seconds to milliseconds -- confirm against the producer
        of ``result``). ``text`` is stripped of surrounding whitespace and has any
        remaining tab / line-break characters replaced by spaces.

        Raises:
            KeyError: if ``result`` has no ``"segments"`` key or a segment lacks one of
                the three keys above.
        """
        print("start", "end", "text", sep="\t", file=file)
        for segment in result["segments"]:
            start_ms = round(1000 * segment["start"])
            end_ms = round(1000 * segment["end"])
            # strip() only trims the ends; interior tabs and line breaks must be
            # neutralised separately or they would corrupt the table layout.
            text = segment["text"].strip().translate(self._FIELD_BREAKERS)
            # flush=True is kept from the original so each completed row leaves the
            # stream buffer immediately; the caller's file handling is not visible here.
            print(start_ms, end_ms, text, sep="\t", file=file, flush=True)


class WriteJSON(ResultWriter):
extension: str = "json"

Expand All @@ -114,6 +133,7 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
"json": WriteJSON,
}

Expand All @@ -127,3 +147,4 @@ def write_all(result: dict, file: TextIO):
return write_all

return writers[output_format](output_dir)