# utils.py (6.4 KB)
  1. import json
  2. import os
  3. import sys
  4. import zlib
  5. from typing import Callable, TextIO
  6. system_encoding = sys.getdefaultencoding()
  7. if system_encoding != "utf-8":
  8. def make_safe(string):
  9. # replaces any character not representable using the system default encoding with an '?',
  10. # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
  11. return string.encode(system_encoding, errors="replace").decode(system_encoding)
  12. else:
  13. def make_safe(string):
  14. # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
  15. return string
  16. def exact_div(x, y):
  17. assert x % y == 0
  18. return x // y
  19. def str2bool(string):
  20. str2val = {"True": True, "False": False}
  21. if string in str2val:
  22. return str2val[string]
  23. else:
  24. raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
  25. def optional_int(string):
  26. return None if string == "None" else int(string)
  27. def optional_float(string):
  28. return None if string == "None" else float(string)
  29. def compression_ratio(text) -> float:
  30. text_bytes = text.encode("utf-8")
  31. return len(text_bytes) / len(zlib.compress(text_bytes))
  32. def format_timestamp(
  33. seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
  34. ):
  35. assert seconds >= 0, "non-negative timestamp expected"
  36. milliseconds = round(seconds * 1000.0)
  37. hours = milliseconds // 3_600_000
  38. milliseconds -= hours * 3_600_000
  39. minutes = milliseconds // 60_000
  40. milliseconds -= minutes * 60_000
  41. seconds = milliseconds // 1_000
  42. milliseconds -= seconds * 1_000
  43. hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
  44. return (
  45. f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
  46. )
  47. class ResultWriter:
  48. extension: str
  49. def __init__(self, output_dir: str):
  50. self.output_dir = output_dir
  51. def __call__(self, result: dict, audio_path: str):
  52. audio_basename = os.path.basename(audio_path)
  53. output_path = os.path.join(
  54. self.output_dir, audio_basename + "." + self.extension
  55. )
  56. with open(output_path, "w", encoding="utf-8") as f:
  57. self.write_result(result, file=f)
  58. def write_result(self, result: dict, file: TextIO):
  59. raise NotImplementedError
  60. class WriteTXT(ResultWriter):
  61. extension: str = "txt"
  62. def write_result(self, result: dict, file: TextIO):
  63. for segment in result["segments"]:
  64. print(segment["text"].strip(), file=file, flush=True)
  65. class SubtitlesWriter(ResultWriter):
  66. always_include_hours: bool
  67. decimal_marker: str
  68. def iterate_result(self, result: dict):
  69. for segment in result["segments"]:
  70. segment_start = self.format_timestamp(segment["start"])
  71. segment_end = self.format_timestamp(segment["end"])
  72. segment_text = segment["text"].strip().replace("-->", "->")
  73. if word_timings := segment.get("words", None):
  74. all_words = [timing["word"] for timing in word_timings]
  75. all_words[0] = all_words[0].strip() # remove the leading space, if any
  76. last = segment_start
  77. for i, this_word in enumerate(word_timings):
  78. start = self.format_timestamp(this_word["start"])
  79. end = self.format_timestamp(this_word["end"])
  80. if last != start:
  81. yield last, start, segment_text
  82. yield start, end, "".join(
  83. [
  84. f"<u>{word}</u>" if j == i else word
  85. for j, word in enumerate(all_words)
  86. ]
  87. )
  88. last = end
  89. if last != segment_end:
  90. yield last, segment_end, segment_text
  91. else:
  92. yield segment_start, segment_end, segment_text
  93. def format_timestamp(self, seconds: float):
  94. return format_timestamp(
  95. seconds=seconds,
  96. always_include_hours=self.always_include_hours,
  97. decimal_marker=self.decimal_marker,
  98. )
  99. class WriteVTT(SubtitlesWriter):
  100. extension: str = "vtt"
  101. always_include_hours: bool = False
  102. decimal_marker: str = "."
  103. def write_result(self, result: dict, file: TextIO):
  104. print("WEBVTT\n", file=file)
  105. for start, end, text in self.iterate_result(result):
  106. print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
  107. class WriteSRT(SubtitlesWriter):
  108. extension: str = "srt"
  109. always_include_hours: bool = True
  110. decimal_marker: str = ","
  111. def write_result(self, result: dict, file: TextIO):
  112. for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
  113. print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
  114. class WriteTSV(ResultWriter):
  115. """
  116. Write a transcript to a file in TSV (tab-separated values) format containing lines like:
  117. <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
  118. Using integer milliseconds as start and end times means there's no chance of interference from
  119. an environment setting a language encoding that causes the decimal in a floating point number
  120. to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
  121. """
  122. extension: str = "tsv"
  123. def write_result(self, result: dict, file: TextIO):
  124. print("start", "end", "text", sep="\t", file=file)
  125. for segment in result["segments"]:
  126. print(round(1000 * segment["start"]), file=file, end="\t")
  127. print(round(1000 * segment["end"]), file=file, end="\t")
  128. print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
  129. class WriteJSON(ResultWriter):
  130. extension: str = "json"
  131. def write_result(self, result: dict, file: TextIO):
  132. json.dump(result, file)
  133. def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
  134. writers = {
  135. "txt": WriteTXT,
  136. "vtt": WriteVTT,
  137. "srt": WriteSRT,
  138. "tsv": WriteTSV,
  139. "json": WriteJSON,
  140. }
  141. if output_format == "all":
  142. all_writers = [writer(output_dir) for writer in writers.values()]
  143. def write_all(result: dict, file: TextIO):
  144. for writer in all_writers:
  145. writer(result, file)
  146. return write_all
  147. return writers[output_format](output_dir)