|
@@ -74,7 +74,9 @@ class ResultWriter:
|
|
def __init__(self, output_dir: str):
|
|
def __init__(self, output_dir: str):
|
|
self.output_dir = output_dir
|
|
self.output_dir = output_dir
|
|
|
|
|
|
- def __call__(self, result: dict, audio_path: str, options: dict):
|
|
|
|
|
|
+ def __call__(
|
|
|
|
+ self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
|
|
|
+ ):
|
|
audio_basename = os.path.basename(audio_path)
|
|
audio_basename = os.path.basename(audio_path)
|
|
audio_basename = os.path.splitext(audio_basename)[0]
|
|
audio_basename = os.path.splitext(audio_basename)[0]
|
|
output_path = os.path.join(
|
|
output_path = os.path.join(
|
|
@@ -82,16 +84,20 @@ class ResultWriter:
|
|
)
|
|
)
|
|
|
|
|
|
with open(output_path, "w", encoding="utf-8") as f:
|
|
with open(output_path, "w", encoding="utf-8") as f:
|
|
- self.write_result(result, file=f, options=options)
|
|
|
|
|
|
+ self.write_result(result, file=f, options=options, **kwargs)
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
|
|
|
+ def write_result(
|
|
|
|
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
|
|
|
+ ):
|
|
raise NotImplementedError
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
class WriteTXT(ResultWriter):
|
|
class WriteTXT(ResultWriter):
|
|
extension: str = "txt"
|
|
extension: str = "txt"
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
|
|
|
+ def write_result(
|
|
|
|
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
|
|
|
+ ):
|
|
for segment in result["segments"]:
|
|
for segment in result["segments"]:
|
|
print(segment["text"].strip(), file=file, flush=True)
|
|
print(segment["text"].strip(), file=file, flush=True)
|
|
|
|
|
|
@@ -100,12 +106,24 @@ class SubtitlesWriter(ResultWriter):
|
|
always_include_hours: bool
|
|
always_include_hours: bool
|
|
decimal_marker: str
|
|
decimal_marker: str
|
|
|
|
|
|
- def iterate_result(self, result: dict, options: dict):
|
|
|
|
- raw_max_line_width: Optional[int] = options["max_line_width"]
|
|
|
|
- max_line_count: Optional[int] = options["max_line_count"]
|
|
|
|
- highlight_words: bool = options["highlight_words"]
|
|
|
|
- max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
|
|
|
- preserve_segments = max_line_count is None or raw_max_line_width is None
|
|
|
|
|
|
+ def iterate_result(
|
|
|
|
+ self,
|
|
|
|
+ result: dict,
|
|
|
|
+ options: Optional[dict] = None,
|
|
|
|
+ *,
|
|
|
|
+ max_line_width: Optional[int] = None,
|
|
|
|
+ max_line_count: Optional[int] = None,
|
|
|
|
+ highlight_words: bool = False,
|
|
|
|
+ max_words_per_line: Optional[int] = None,
|
|
|
|
+ ):
|
|
|
|
+ options = options or {}
|
|
|
|
+ max_line_width = max_line_width or options.get("max_line_width")
|
|
|
|
+ max_line_count = max_line_count or options.get("max_line_count")
|
|
|
|
+ highlight_words = highlight_words or options.get("highlight_words", False)
|
|
|
|
+ max_words_per_line = max_words_per_line or options.get("max_words_per_line")
|
|
|
|
+ preserve_segments = max_line_count is None or max_line_width is None
|
|
|
|
+ max_line_width = max_line_width or 1000
|
|
|
|
+ max_words_per_line = max_words_per_line or 1000
|
|
|
|
|
|
def iterate_subtitles():
|
|
def iterate_subtitles():
|
|
line_len = 0
|
|
line_len = 0
|
|
@@ -114,34 +132,50 @@ class SubtitlesWriter(ResultWriter):
|
|
subtitle: list[dict] = []
|
|
subtitle: list[dict] = []
|
|
last = result["segments"][0]["words"][0]["start"]
|
|
last = result["segments"][0]["words"][0]["start"]
|
|
for segment in result["segments"]:
|
|
for segment in result["segments"]:
|
|
- for i, original_timing in enumerate(segment["words"]):
|
|
|
|
- timing = original_timing.copy()
|
|
|
|
- long_pause = not preserve_segments and timing["start"] - last > 3.0
|
|
|
|
- has_room = line_len + len(timing["word"]) <= max_line_width
|
|
|
|
- seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
|
|
|
- if line_len > 0 and has_room and not long_pause and not seg_break:
|
|
|
|
- # line continuation
|
|
|
|
- line_len += len(timing["word"])
|
|
|
|
- else:
|
|
|
|
- # new line
|
|
|
|
- timing["word"] = timing["word"].strip()
|
|
|
|
|
|
+ chunk_index = 0
|
|
|
|
+ words_count = max_words_per_line
|
|
|
|
+ while chunk_index < len(segment["words"]):
|
|
|
|
+ remaining_words = len(segment["words"]) - chunk_index
|
|
|
|
+ if max_words_per_line > len(segment["words"]) - chunk_index:
|
|
|
|
+ words_count = remaining_words
|
|
|
|
+ for i, original_timing in enumerate(
|
|
|
|
+ segment["words"][chunk_index : chunk_index + words_count]
|
|
|
|
+ ):
|
|
|
|
+ timing = original_timing.copy()
|
|
|
|
+ long_pause = (
|
|
|
|
+ not preserve_segments and timing["start"] - last > 3.0
|
|
|
|
+ )
|
|
|
|
+ has_room = line_len + len(timing["word"]) <= max_line_width
|
|
|
|
+ seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
|
if (
|
|
if (
|
|
- len(subtitle) > 0
|
|
|
|
- and max_line_count is not None
|
|
|
|
- and (long_pause or line_count >= max_line_count)
|
|
|
|
- or seg_break
|
|
|
|
|
|
+ line_len > 0
|
|
|
|
+ and has_room
|
|
|
|
+ and not long_pause
|
|
|
|
+ and not seg_break
|
|
):
|
|
):
|
|
- # subtitle break
|
|
|
|
- yield subtitle
|
|
|
|
- subtitle = []
|
|
|
|
- line_count = 1
|
|
|
|
- elif line_len > 0:
|
|
|
|
- # line break
|
|
|
|
- line_count += 1
|
|
|
|
- timing["word"] = "\n" + timing["word"]
|
|
|
|
- line_len = len(timing["word"].strip())
|
|
|
|
- subtitle.append(timing)
|
|
|
|
- last = timing["start"]
|
|
|
|
|
|
+ # line continuation
|
|
|
|
+ line_len += len(timing["word"])
|
|
|
|
+ else:
|
|
|
|
+ # new line
|
|
|
|
+ timing["word"] = timing["word"].strip()
|
|
|
|
+ if (
|
|
|
|
+ len(subtitle) > 0
|
|
|
|
+ and max_line_count is not None
|
|
|
|
+ and (long_pause or line_count >= max_line_count)
|
|
|
|
+ or seg_break
|
|
|
|
+ ):
|
|
|
|
+ # subtitle break
|
|
|
|
+ yield subtitle
|
|
|
|
+ subtitle = []
|
|
|
|
+ line_count = 1
|
|
|
|
+ elif line_len > 0:
|
|
|
|
+ # line break
|
|
|
|
+ line_count += 1
|
|
|
|
+ timing["word"] = "\n" + timing["word"]
|
|
|
|
+ line_len = len(timing["word"].strip())
|
|
|
|
+ subtitle.append(timing)
|
|
|
|
+ last = timing["start"]
|
|
|
|
+ chunk_index += max_words_per_line
|
|
if len(subtitle) > 0:
|
|
if len(subtitle) > 0:
|
|
yield subtitle
|
|
yield subtitle
|
|
|
|
|
|
@@ -190,9 +224,11 @@ class WriteVTT(SubtitlesWriter):
|
|
always_include_hours: bool = False
|
|
always_include_hours: bool = False
|
|
decimal_marker: str = "."
|
|
decimal_marker: str = "."
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
|
|
|
+ def write_result(
|
|
|
|
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
|
|
|
+ ):
|
|
print("WEBVTT\n", file=file)
|
|
print("WEBVTT\n", file=file)
|
|
- for start, end, text in self.iterate_result(result, options):
|
|
|
|
|
|
+ for start, end, text in self.iterate_result(result, options, **kwargs):
|
|
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
|
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
|
|
|
|
|
|
|
|
|
@@ -201,9 +237,11 @@ class WriteSRT(SubtitlesWriter):
|
|
always_include_hours: bool = True
|
|
always_include_hours: bool = True
|
|
decimal_marker: str = ","
|
|
decimal_marker: str = ","
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
|
|
|
+ def write_result(
|
|
|
|
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
|
|
|
+ ):
|
|
for i, (start, end, text) in enumerate(
|
|
for i, (start, end, text) in enumerate(
|
|
- self.iterate_result(result, options), start=1
|
|
|
|
|
|
+ self.iterate_result(result, options, **kwargs), start=1
|
|
):
|
|
):
|
|
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
|
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
|
|
|
|
|
@@ -220,7 +258,9 @@ class WriteTSV(ResultWriter):
|
|
|
|
|
|
extension: str = "tsv"
|
|
extension: str = "tsv"
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
|
|
|
+ def write_result(
|
|
|
|
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
|
|
|
+ ):
|
|
print("start", "end", "text", sep="\t", file=file)
|
|
print("start", "end", "text", sep="\t", file=file)
|
|
for segment in result["segments"]:
|
|
for segment in result["segments"]:
|
|
print(round(1000 * segment["start"]), file=file, end="\t")
|
|
print(round(1000 * segment["start"]), file=file, end="\t")
|
|
@@ -231,7 +271,9 @@ class WriteTSV(ResultWriter):
|
|
class WriteJSON(ResultWriter):
|
|
class WriteJSON(ResultWriter):
|
|
extension: str = "json"
|
|
extension: str = "json"
|
|
|
|
|
|
- def write_result(self, result: dict, file: TextIO, options: dict):
|
|
|
|
|
|
+ def write_result(
|
|
|
|
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
|
|
|
+ ):
|
|
json.dump(result, file)
|
|
json.dump(result, file)
|
|
|
|
|
|
|
|
|
|
@@ -249,9 +291,11 @@ def get_writer(
|
|
if output_format == "all":
|
|
if output_format == "all":
|
|
all_writers = [writer(output_dir) for writer in writers.values()]
|
|
all_writers = [writer(output_dir) for writer in writers.values()]
|
|
|
|
|
|
- def write_all(result: dict, file: TextIO, options: dict):
|
|
|
|
|
|
+ def write_all(
|
|
|
|
+ result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
|
|
|
+ ):
|
|
for writer in all_writers:
|
|
for writer in all_writers:
|
|
- writer(result, file, options)
|
|
|
|
|
|
+ writer(result, file, options, **kwargs)
|
|
|
|
|
|
return write_all
|
|
return write_all
|
|
|
|
|