# utils.py (6.5 KB)
  1. import json
  2. import os
  3. import sys
  4. import zlib
  5. from typing import Callable, TextIO
  6. system_encoding = sys.getdefaultencoding()
  7. if system_encoding != "utf-8":
  8. def make_safe(string):
  9. # replaces any character not representable using the system default encoding with an '?',
  10. # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
  11. return string.encode(system_encoding, errors="replace").decode(system_encoding)
  12. else:
  13. def make_safe(string):
  14. # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
  15. return string
  16. def exact_div(x, y):
  17. assert x % y == 0
  18. return x // y
  19. def str2bool(string):
  20. str2val = {"True": True, "False": False}
  21. if string in str2val:
  22. return str2val[string]
  23. else:
  24. raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
  25. def optional_int(string):
  26. return None if string == "None" else int(string)
  27. def optional_float(string):
  28. return None if string == "None" else float(string)
  29. def compression_ratio(text) -> float:
  30. text_bytes = text.encode("utf-8")
  31. return len(text_bytes) / len(zlib.compress(text_bytes))
  32. def format_timestamp(
  33. seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
  34. ):
  35. assert seconds >= 0, "non-negative timestamp expected"
  36. milliseconds = round(seconds * 1000.0)
  37. hours = milliseconds // 3_600_000
  38. milliseconds -= hours * 3_600_000
  39. minutes = milliseconds // 60_000
  40. milliseconds -= minutes * 60_000
  41. seconds = milliseconds // 1_000
  42. milliseconds -= seconds * 1_000
  43. hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
  44. return (
  45. f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
  46. )
  47. class ResultWriter:
  48. extension: str
  49. def __init__(self, output_dir: str):
  50. self.output_dir = output_dir
  51. def __call__(self, result: dict, audio_path: str):
  52. audio_basename = os.path.basename(audio_path)
  53. audio_basename = os.path.splitext(audio_basename)[0]
  54. output_path = os.path.join(
  55. self.output_dir, audio_basename + "." + self.extension
  56. )
  57. with open(output_path, "w", encoding="utf-8") as f:
  58. self.write_result(result, file=f)
  59. def write_result(self, result: dict, file: TextIO):
  60. raise NotImplementedError
  61. class WriteTXT(ResultWriter):
  62. extension: str = "txt"
  63. def write_result(self, result: dict, file: TextIO):
  64. for segment in result["segments"]:
  65. print(segment["text"].strip(), file=file, flush=True)
  66. class SubtitlesWriter(ResultWriter):
  67. always_include_hours: bool
  68. decimal_marker: str
  69. def iterate_result(self, result: dict):
  70. for segment in result["segments"]:
  71. segment_start = self.format_timestamp(segment["start"])
  72. segment_end = self.format_timestamp(segment["end"])
  73. segment_text = segment["text"].strip().replace("-->", "->")
  74. if word_timings := segment.get("words", None):
  75. all_words = [timing["word"] for timing in word_timings]
  76. all_words[0] = all_words[0].strip() # remove the leading space, if any
  77. last = segment_start
  78. for i, this_word in enumerate(word_timings):
  79. start = self.format_timestamp(this_word["start"])
  80. end = self.format_timestamp(this_word["end"])
  81. if last != start:
  82. yield last, start, segment_text
  83. yield start, end, "".join(
  84. [
  85. f"<u>{word}</u>" if j == i else word
  86. for j, word in enumerate(all_words)
  87. ]
  88. )
  89. last = end
  90. if last != segment_end:
  91. yield last, segment_end, segment_text
  92. else:
  93. yield segment_start, segment_end, segment_text
  94. def format_timestamp(self, seconds: float):
  95. return format_timestamp(
  96. seconds=seconds,
  97. always_include_hours=self.always_include_hours,
  98. decimal_marker=self.decimal_marker,
  99. )
  100. class WriteVTT(SubtitlesWriter):
  101. extension: str = "vtt"
  102. always_include_hours: bool = False
  103. decimal_marker: str = "."
  104. def write_result(self, result: dict, file: TextIO):
  105. print("WEBVTT\n", file=file)
  106. for start, end, text in self.iterate_result(result):
  107. print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
  108. class WriteSRT(SubtitlesWriter):
  109. extension: str = "srt"
  110. always_include_hours: bool = True
  111. decimal_marker: str = ","
  112. def write_result(self, result: dict, file: TextIO):
  113. for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
  114. print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
  115. class WriteTSV(ResultWriter):
  116. """
  117. Write a transcript to a file in TSV (tab-separated values) format containing lines like:
  118. <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
  119. Using integer milliseconds as start and end times means there's no chance of interference from
  120. an environment setting a language encoding that causes the decimal in a floating point number
  121. to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
  122. """
  123. extension: str = "tsv"
  124. def write_result(self, result: dict, file: TextIO):
  125. print("start", "end", "text", sep="\t", file=file)
  126. for segment in result["segments"]:
  127. print(round(1000 * segment["start"]), file=file, end="\t")
  128. print(round(1000 * segment["end"]), file=file, end="\t")
  129. print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
  130. class WriteJSON(ResultWriter):
  131. extension: str = "json"
  132. def write_result(self, result: dict, file: TextIO):
  133. json.dump(result, file)
  134. def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
  135. writers = {
  136. "txt": WriteTXT,
  137. "vtt": WriteVTT,
  138. "srt": WriteSRT,
  139. "tsv": WriteTSV,
  140. "json": WriteJSON,
  141. }
  142. if output_format == "all":
  143. all_writers = [writer(output_dir) for writer in writers.values()]
  144. def write_all(result: dict, file: TextIO):
  145. for writer in all_writers:
  146. writer(result, file)
  147. return write_all
  148. return writers[output_format](output_dir)