|
@@ -1,8 +1,9 @@
|
|
|
import json
|
|
|
import os
|
|
|
+import re
|
|
|
import sys
|
|
|
import zlib
|
|
|
-from typing import Callable, TextIO
|
|
|
+from typing import Callable, Optional, TextIO
|
|
|
|
|
|
system_encoding = sys.getdefaultencoding()
|
|
|
|
|
@@ -73,7 +74,7 @@ class ResultWriter:
|
|
|
def __init__(self, output_dir: str):
|
|
|
self.output_dir = output_dir
|
|
|
|
|
|
- def __call__(self, result: dict, audio_path: str):
|
|
|
+ def __call__(self, result: dict, audio_path: str, options: dict):
|
|
|
audio_basename = os.path.basename(audio_path)
|
|
|
audio_basename = os.path.splitext(audio_basename)[0]
|
|
|
output_path = os.path.join(
|
|
@@ -81,16 +82,16 @@ class ResultWriter:
|
|
|
)
|
|
|
|
|
|
with open(output_path, "w", encoding="utf-8") as f:
|
|
|
- self.write_result(result, file=f)
|
|
|
+ self.write_result(result, file=f, options=options)
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO):
|
|
|
+ def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
class WriteTXT(ResultWriter):
|
|
|
extension: str = "txt"
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO):
|
|
|
+ def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
for segment in result["segments"]:
|
|
|
print(segment["text"].strip(), file=file, flush=True)
|
|
|
|
|
@@ -99,33 +100,81 @@ class SubtitlesWriter(ResultWriter):
|
|
|
always_include_hours: bool
|
|
|
decimal_marker: str
|
|
|
|
|
|
- def iterate_result(self, result: dict):
|
|
|
- for segment in result["segments"]:
|
|
|
- segment_start = self.format_timestamp(segment["start"])
|
|
|
- segment_end = self.format_timestamp(segment["end"])
|
|
|
- segment_text = segment["text"].strip().replace("-->", "->")
|
|
|
-
|
|
|
- if word_timings := segment.get("words", None):
|
|
|
- all_words = [timing["word"] for timing in word_timings]
|
|
|
- all_words[0] = all_words[0].strip() # remove the leading space, if any
|
|
|
- last = segment_start
|
|
|
- for i, this_word in enumerate(word_timings):
|
|
|
- start = self.format_timestamp(this_word["start"])
|
|
|
- end = self.format_timestamp(this_word["end"])
|
|
|
- if last != start:
|
|
|
- yield last, start, segment_text
|
|
|
-
|
|
|
- yield start, end, "".join(
|
|
|
- [
|
|
|
- f"<u>{word}</u>" if j == i else word
|
|
|
- for j, word in enumerate(all_words)
|
|
|
- ]
|
|
|
- )
|
|
|
- last = end
|
|
|
-
|
|
|
- if last != segment_end:
|
|
|
- yield last, segment_end, segment_text
|
|
|
- else:
|
|
|
+ def iterate_result(self, result: dict, options: dict):
|
|
|
+ raw_max_line_width: Optional[int] = options["max_line_width"]
|
|
|
+ max_line_count: Optional[int] = options["max_line_count"]
|
|
|
+ highlight_words: bool = options["highlight_words"]
|
|
|
+ max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
|
|
+ preserve_segments = max_line_count is None or raw_max_line_width is None
|
|
|
+
|
|
|
+ def iterate_subtitles():
|
|
|
+ line_len = 0
|
|
|
+ line_count = 1
|
|
|
+ # the next subtitle to yield (a list of word timings with whitespace)
|
|
|
+ subtitle: list[dict] = []
|
|
|
+ last = result["segments"][0]["words"][0]["start"]
|
|
|
+ for segment in result["segments"]:
|
|
|
+ for i, original_timing in enumerate(segment["words"]):
|
|
|
+ timing = original_timing.copy()
|
|
|
+ long_pause = not preserve_segments and timing["start"] - last > 3.0
|
|
|
+ has_room = line_len + len(timing["word"]) <= max_line_width
|
|
|
+ seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
|
|
+ if line_len > 0 and has_room and not long_pause and not seg_break:
|
|
|
+ # line continuation
|
|
|
+ line_len += len(timing["word"])
|
|
|
+ else:
|
|
|
+ # new line
|
|
|
+ timing["word"] = timing["word"].strip()
|
|
|
+ if (
|
|
|
+ len(subtitle) > 0
|
|
|
+ and max_line_count is not None
|
|
|
+ and (long_pause or line_count >= max_line_count)
|
|
|
+ or seg_break
|
|
|
+ ):
|
|
|
+ # subtitle break
|
|
|
+ yield subtitle
|
|
|
+ subtitle = []
|
|
|
+ line_count = 1
|
|
|
+ elif line_len > 0:
|
|
|
+ # line break
|
|
|
+ line_count += 1
|
|
|
+ timing["word"] = "\n" + timing["word"]
|
|
|
+ line_len = len(timing["word"].strip())
|
|
|
+ subtitle.append(timing)
|
|
|
+ last = timing["start"]
|
|
|
+ if len(subtitle) > 0:
|
|
|
+ yield subtitle
|
|
|
+
|
|
|
+ if "words" in result["segments"][0]:
|
|
|
+ for subtitle in iterate_subtitles():
|
|
|
+ subtitle_start = self.format_timestamp(subtitle[0]["start"])
|
|
|
+ subtitle_end = self.format_timestamp(subtitle[-1]["end"])
|
|
|
+ subtitle_text = "".join([word["word"] for word in subtitle])
|
|
|
+ if highlight_words:
|
|
|
+ last = subtitle_start
|
|
|
+ all_words = [timing["word"] for timing in subtitle]
|
|
|
+ for i, this_word in enumerate(subtitle):
|
|
|
+ start = self.format_timestamp(this_word["start"])
|
|
|
+ end = self.format_timestamp(this_word["end"])
|
|
|
+ if last != start:
|
|
|
+ yield last, start, subtitle_text
|
|
|
+
|
|
|
+ yield start, end, "".join(
|
|
|
+ [
|
|
|
+ re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
|
|
+ if j == i
|
|
|
+ else word
|
|
|
+ for j, word in enumerate(all_words)
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ last = end
|
|
|
+ else:
|
|
|
+ yield subtitle_start, subtitle_end, subtitle_text
|
|
|
+ else:
|
|
|
+ for segment in result["segments"]:
|
|
|
+ segment_start = self.format_timestamp(segment["start"])
|
|
|
+ segment_end = self.format_timestamp(segment["end"])
|
|
|
+ segment_text = segment["text"].strip().replace("-->", "->")
|
|
|
yield segment_start, segment_end, segment_text
|
|
|
|
|
|
def format_timestamp(self, seconds: float):
|
|
@@ -141,9 +190,9 @@ class WriteVTT(SubtitlesWriter):
|
|
|
always_include_hours: bool = False
|
|
|
decimal_marker: str = "."
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO):
|
|
|
+ def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
print("WEBVTT\n", file=file)
|
|
|
- for start, end, text in self.iterate_result(result):
|
|
|
+ for start, end, text in self.iterate_result(result, options):
|
|
|
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
|
|
|
|
|
|
|
@@ -152,8 +201,10 @@ class WriteSRT(SubtitlesWriter):
|
|
|
always_include_hours: bool = True
|
|
|
decimal_marker: str = ","
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO):
|
|
|
- for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
|
|
|
+ def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
+ for i, (start, end, text) in enumerate(
|
|
|
+ self.iterate_result(result, options), start=1
|
|
|
+ ):
|
|
|
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
|
|
|
|
|
|
|
@@ -169,7 +220,7 @@ class WriteTSV(ResultWriter):
|
|
|
|
|
|
extension: str = "tsv"
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO):
|
|
|
+ def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
print("start", "end", "text", sep="\t", file=file)
|
|
|
for segment in result["segments"]:
|
|
|
print(round(1000 * segment["start"]), file=file, end="\t")
|
|
@@ -180,11 +231,13 @@ class WriteTSV(ResultWriter):
|
|
|
class WriteJSON(ResultWriter):
|
|
|
extension: str = "json"
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO):
|
|
|
+ def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
json.dump(result, file)
|
|
|
|
|
|
|
|
|
-def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
|
|
|
+def get_writer(
|
|
|
+ output_format: str, output_dir: str
|
|
|
+) -> Callable[[dict, TextIO, dict], None]:
|
|
|
writers = {
|
|
|
"txt": WriteTXT,
|
|
|
"vtt": WriteVTT,
|
|
@@ -196,9 +249,9 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
|
|
|
if output_format == "all":
|
|
|
all_writers = [writer(output_dir) for writer in writers.values()]
|
|
|
|
|
|
- def write_all(result: dict, file: TextIO):
|
|
|
+ def write_all(result: dict, file: TextIO, options: dict):
|
|
|
for writer in all_writers:
|
|
|
- writer(result, file)
|
|
|
+ writer(result, file, options)
|
|
|
|
|
|
return write_all
|
|
|
|