123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- import json
- import os
- import sys
- import zlib
- from typing import Callable, TextIO
system_encoding = sys.getdefaultencoding()

if system_encoding == "utf-8":

    def make_safe(string):
        """Return *string* unchanged.

        The default encoding is UTF-8, which can represent any ``str``, so no
        character substitution is necessary.
        """
        return string

else:

    def make_safe(string):
        """Return *string* restricted to characters the default encoding supports.

        The text is round-tripped through ``system_encoding`` with
        ``errors="replace"``; unencodable characters come back as that codec's
        replacement marker (typically ``?``).
        """
        encoded = string.encode(system_encoding, errors="replace")
        return encoded.decode(system_encoding)
def exact_div(x, y):
    """Return ``x // y``, asserting that *y* divides *x* with no remainder.

    Raises:
        AssertionError: if ``x % y`` is non-zero (only while assertions are
            enabled; under ``python -O`` the floor quotient is returned).
        ZeroDivisionError: if *y* is zero.
    """
    remainder = x % y
    assert remainder == 0
    quotient = x // y
    return quotient
def str2bool(string):
    """Parse the exact strings ``"True"`` / ``"False"`` into a ``bool``.

    Matching is case-sensitive; e.g. ``"true"`` is rejected.

    Raises:
        ValueError: if *string* is neither ``"True"`` nor ``"False"``.
    """
    literals = {"True": True, "False": False}
    if string not in literals:
        raise ValueError(f"Expected one of {set(literals.keys())}, got {string}")
    return literals[string]
def optional_int(string):
    """Convert *string* with ``int``, treating the literal ``"None"`` as ``None``."""
    if string == "None":
        return None
    return int(string)
def optional_float(string):
    """Convert *string* with ``float``, treating the literal ``"None"`` as ``None``."""
    if string == "None":
        return None
    return float(string)
def compression_ratio(text) -> float:
    """Return the ratio of *text*'s UTF-8 size to its zlib-compressed size.

    Repetitive text compresses well and therefore yields a large ratio.
    """
    raw = text.encode("utf-8")
    compressed = zlib.compress(raw)
    return len(raw) / len(compressed)
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    """Format a non-negative time in seconds as ``[HH:]MM:SS<marker>mmm``.

    The value is first rounded to the nearest millisecond. The ``HH:`` field
    is included when *always_include_hours* is true or the value reaches one
    hour; *decimal_marker* separates the seconds from the milliseconds.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)
    hours, rest = divmod(total_ms, 3_600_000)
    minutes, rest = divmod(rest, 60_000)
    whole_seconds, millis = divmod(rest, 1_000)
    prefix = f"{hours:02d}:" if (always_include_hours or hours > 0) else ""
    return f"{prefix}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
class ResultWriter:
    """Base class for writers that save a transcription result to a file.

    Subclasses set ``extension`` and override ``write_result``. Calling an
    instance opens ``<output_dir>/<stem of audio_path>.<extension>`` for
    writing as UTF-8 text and delegates to ``write_result``.
    """

    extension: str  # output file extension without the leading dot; set by subclasses

    def __init__(self, output_dir: str):
        self.output_dir = output_dir  # directory in which output files are created

    def __call__(self, result: dict, audio_path: str):
        """Write *result* to the file derived from *audio_path*, truncating any existing file."""
        stem, _ = os.path.splitext(os.path.basename(audio_path))
        target = os.path.join(self.output_dir, stem + "." + self.extension)
        with open(target, "w", encoding="utf-8") as handle:
            self.write_result(result, file=handle)

    def write_result(self, result: dict, file: TextIO):
        """Serialize *result* into the open text stream *file*; subclasses must override."""
        raise NotImplementedError
class WriteTXT(ResultWriter):
    """Plain-text writer: one line per segment containing its stripped text."""

    extension: str = "txt"

    def write_result(self, result: dict, file: TextIO):
        segments = result["segments"]
        for seg in segments:
            line = seg["text"].strip()
            print(line, file=file, flush=True)
class SubtitlesWriter(ResultWriter):
    """Shared cue generation for subtitle formats.

    Subclasses set ``always_include_hours`` and ``decimal_marker``, which
    control how ``format_timestamp`` renders cue times.
    """

    always_include_hours: bool  # passed through to the module-level format_timestamp
    decimal_marker: str  # separator between seconds and milliseconds

    def iterate_result(self, result: dict):
        """Yield ``(start, end, text)`` cue tuples for every segment in *result*.

        Without word timings a segment becomes a single cue. When the segment
        has a non-empty ``"words"`` list, one cue is yielded per word, showing
        the whole segment with that word wrapped in ``<u>...</u>``. Any gap
        between the segment start, consecutive words, and the segment end is
        filled with a cue showing the plain segment text.

        ``"-->"`` is rewritten to ``"->"`` in every cue's text, because the
        writers use ``"-->"`` as the start/end separator on the timing line.
        """
        for segment in result["segments"]:
            segment_start = self.format_timestamp(segment["start"])
            segment_end = self.format_timestamp(segment["end"])
            segment_text = segment["text"].strip().replace("-->", "->")

            if word_timings := segment.get("words", None):
                all_words = [timing["word"] for timing in word_timings]
                # Drop leading whitespace so the cue text starts like segment_text.
                all_words[0] = all_words[0].strip()
                last = segment_start
                for i, this_word in enumerate(word_timings):
                    start = self.format_timestamp(this_word["start"])
                    end = self.format_timestamp(this_word["end"])
                    if last != start:
                        # Gap before this word: show the unhighlighted segment.
                        yield last, start, segment_text
                    highlighted = "".join(
                        f"<u>{word}</u>" if j == i else word
                        for j, word in enumerate(all_words)
                    )
                    # Escape on the joined text so an arrow spanning two
                    # adjacent words is caught too; the <u> tags cannot form "-->".
                    yield start, end, highlighted.replace("-->", "->")
                    last = end
                if last != segment_end:
                    yield last, segment_end, segment_text
            else:
                yield segment_start, segment_end, segment_text

    def format_timestamp(self, seconds: float):
        """Format *seconds* using this writer's hour and decimal-marker settings."""
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )
class WriteVTT(SubtitlesWriter):
    """WebVTT writer: a ``WEBVTT`` header followed by one cue per yielded interval.

    Timestamps use ``.`` before the milliseconds and include the hours field
    only when it is non-zero.
    """

    extension: str = "vtt"
    always_include_hours: bool = False
    decimal_marker: str = "."

    def write_result(self, result: dict, file: TextIO):
        print("WEBVTT\n", file=file)
        for cue_start, cue_end, cue_text in self.iterate_result(result):
            cue = f"{cue_start} --> {cue_end}\n{cue_text}\n"
            print(cue, file=file, flush=True)
class WriteSRT(SubtitlesWriter):
    """SubRip writer: cues numbered from 1 with ``HH:MM:SS,mmm`` timestamps."""

    extension: str = "srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def write_result(self, result: dict, file: TextIO):
        cue_number = 0
        for start, end, text in self.iterate_result(result):
            cue_number += 1
            block = f"{cue_number}\n{start} --> {end}\n{text}\n"
            print(block, file=file, flush=True)
class WriteTSV(ResultWriter):
    """
    Write a transcript to a file in TSV (tab-separated values) format containing lines like:
    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>

    Using integer milliseconds as start and end times means there's no chance of interference from
    an environment locale that causes the decimal in a floating point number to appear as a comma;
    it is also faster and more efficient to parse & store, e.g., in C++.

    A header row ``start``/``end``/``text`` is written first. Tabs, carriage returns and newlines
    inside a segment's text are replaced with spaces so each segment occupies exactly one row with
    exactly three columns.
    """

    extension: str = "tsv"

    def write_result(self, result: dict, file: TextIO):
        print("start", "end", "text", sep="\t", file=file)
        # A tab would add a column and a line break would split the row.
        to_space = str.maketrans("\t\r\n", "   ")
        for segment in result["segments"]:
            print(
                round(1000 * segment["start"]),
                round(1000 * segment["end"]),
                segment["text"].strip().translate(to_space),
                sep="\t",
                file=file,
                flush=True,
            )
class WriteJSON(ResultWriter):
    """JSON writer: serializes the entire result dict with ``json.dump`` defaults."""

    extension: str = "json"

    def write_result(self, result: dict, file: TextIO):
        json.dump(obj=result, fp=file)
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, str], None]:
    """Return a callable that saves a transcription result for one audio file.

    The returned callable takes ``(result, audio_path)``, matching
    ``ResultWriter.__call__``; the audio path is used only to derive the
    output file name inside *output_dir*.

    Args:
        output_format: ``"txt"``, ``"vtt"``, ``"srt"``, ``"tsv"``, ``"json"``,
            or ``"all"`` to write every one of those formats.
        output_dir: directory the output file(s) are written to.

    Raises:
        KeyError: if *output_format* is not one of the values above.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
        "json": WriteJSON,
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        def write_all(result: dict, audio_path: str):
            # Fan the same result out to every format-specific writer.
            for writer in all_writers:
                writer(result, audio_path)

        return write_all

    return writers[output_format](output_dir)
|