|
@@ -1,5 +1,7 @@
|
|
|
+import json
|
|
|
+import os
|
|
|
import zlib
|
|
|
-from typing import Iterator, TextIO
|
|
|
+from typing import Callable, TextIO
|
|
|
|
|
|
|
|
|
def exact_div(x, y):
|
|
@@ -45,44 +47,83 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
|
|
|
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
|
|
|
|
|
|
|
|
def write_txt(transcript: Iterator[dict], file: TextIO):
    """
    Write a transcript to a file as plain text, one segment per line.

    Each segment's "text" value is stripped of surrounding whitespace and
    printed on its own line; the file is flushed after every segment.
    """
    stripped_lines = (entry['text'].strip() for entry in transcript)
    for line in stripped_lines:
        print(line, file=file, flush=True)
|
|
|
-
|
|
|
-
|
|
|
def write_vtt(transcript: Iterator[dict], file: TextIO):
    """
    Write a transcript to a file in WebVTT format.

    A "WEBVTT" header is printed first, then one cue per segment: a timing
    line built from the segment's "start" and "end" values, followed by the
    segment's stripped "text" and a blank line.
    """
    print("WEBVTT\n", file=file)
    for cue in transcript:
        begin = format_timestamp(cue['start'])
        finish = format_timestamp(cue['end'])
        # "-->" separates the two timestamps on the timing line, so it is
        # rewritten to "->" wherever it occurs inside the cue text.
        body = cue['text'].strip().replace('-->', '->')
        print(f"{begin} --> {finish}\n{body}\n", file=file, flush=True)
|
|
|
-
|
|
|
-
|
|
|
def write_srt(transcript: Iterator[dict], file: TextIO):
    """
    Write a transcript to a file in SRT format.

    Segments are numbered from 1. Each entry consists of the index, a timing
    line (hours always included, "," as the decimal marker), the segment's
    stripped text, and a blank line.

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt

        result = transcribe(model, audio_path, temperature=temperature, **args)

        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    index = 0
    for entry in transcript:
        index += 1
        begin = format_timestamp(entry['start'], always_include_hours=True, decimal_marker=',')
        finish = format_timestamp(entry['end'], always_include_hours=True, decimal_marker=',')
        # "-->" separates the two timestamps on the timing line, so it is
        # rewritten to "->" wherever it occurs inside the subtitle text.
        body = entry['text'].strip().replace('-->', '->')
        print(f"{index}\n{begin} --> {finish}\n{body}\n", file=file, flush=True)
|
|
|
class ResultWriter:
    """
    Base class for writers that save a transcription result to a file.

    Subclasses set `extension` and implement `write_result`. Calling an
    instance opens `<output_dir>/<audio file stem>.<extension>` for writing
    (UTF-8) and passes the open file to `write_result`.
    """

    # File extension without the leading dot; provided by each subclass.
    extension: str

    def __init__(self, output_dir: str):
        """Remember the directory that output files are written into."""
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str):
        """
        Write `result` to a file named after `audio_path` inside `output_dir`.

        The output name is the audio file's base name with its own extension
        removed, followed by "." and this writer's `extension`.
        """
        # Strip both the directory and the audio extension, so "dir/a.mp3"
        # produces "a.txt" rather than "a.mp3.txt".
        audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
        output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)

        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f)

    def write_result(self, result: dict, file: TextIO):
        """
        Serialize `result` into the open text file `file`.

        Raises NotImplementedError here; subclasses must override it.
        """
        raise NotImplementedError
|
|
|
+
|
|
|
+
|
|
|
class WriteTXT(ResultWriter):
    """Writer that emits the transcript as plain text, one segment per line."""

    extension: str = "txt"

    def write_result(self, result: dict, file: TextIO):
        """Print each entry of result["segments"] as its stripped text, flushing per line."""
        stripped_lines = (entry['text'].strip() for entry in result["segments"])
        for line in stripped_lines:
            print(line, file=file, flush=True)
|
|
|
+
|
|
|
+
|
|
|
class WriteVTT(ResultWriter):
    """Writer that emits the transcript in WebVTT format."""

    extension: str = "vtt"

    def write_result(self, result: dict, file: TextIO):
        """
        Print a "WEBVTT" header, then one cue per entry of result["segments"].

        Each cue is a timing line built from the segment's "start" and "end"
        values, the stripped segment text, and a blank line.
        """
        print("WEBVTT\n", file=file)
        for cue in result["segments"]:
            begin = format_timestamp(cue['start'])
            finish = format_timestamp(cue['end'])
            # "-->" separates the two timestamps on the timing line, so it is
            # rewritten to "->" wherever it occurs inside the cue text.
            body = cue['text'].strip().replace('-->', '->')
            print(f"{begin} --> {finish}\n{body}\n", file=file, flush=True)
|
|
|
+
|
|
|
+
|
|
|
class WriteSRT(ResultWriter):
    """Writer that emits the transcript in SRT format."""

    extension: str = "srt"

    def write_result(self, result: dict, file: TextIO):
        """
        Print one numbered entry per item of result["segments"], counting from 1.

        Each entry is the index, a timing line (hours always included, ","
        as the decimal marker), the stripped segment text, and a blank line.
        """
        index = 0
        for entry in result["segments"]:
            index += 1
            begin = format_timestamp(entry['start'], always_include_hours=True, decimal_marker=',')
            finish = format_timestamp(entry['end'], always_include_hours=True, decimal_marker=',')
            # "-->" separates the two timestamps on the timing line, so it is
            # rewritten to "->" wherever it occurs inside the subtitle text.
            body = entry['text'].strip().replace('-->', '->')
            print(f"{index}\n{begin} --> {finish}\n{body}\n", file=file, flush=True)
|
|
|
+
|
|
|
+
|
|
|
class WriteJSON(ResultWriter):
    """Writer that dumps the whole result dictionary as JSON."""

    extension: str = "json"

    def write_result(self, result: dict, file: TextIO):
        """Serialize `result` to `file` with the json module's default settings."""
        json.dump(obj=result, fp=file)
|
|
|
+
|
|
|
+
|
|
|
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, str], None]:
    """
    Build a callable that writes a transcription result into `output_dir`.

    `output_format` is one of "txt", "vtt", "srt", "json", or "all". For a
    single format the matching writer instance is returned. For "all" the
    returned function invokes one writer per known format, in the order
    listed below.

    The returned callable takes `(result, audio_path)`: the result dictionary
    and the path of the audio file the output is named after.

    Raises KeyError if `output_format` is neither a known format nor "all".
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "json": WriteJSON,
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        # The second argument is forwarded to each writer's __call__, which
        # expects the audio path (a string), not an open file.
        def write_all(result: dict, audio_path: str):
            for writer in all_writers:
                writer(result, audio_path)

        return write_all

    return writers[output_format](output_dir)
|