@@ -1,3 +1,4 @@
+import itertools
 import subprocess
 import warnings
 from dataclasses import dataclass
@@ -290,34 +291,41 @@ def add_word_timestamps(
     if len(segments) == 0:
         return
 
-    text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
+    text_tokens_per_segment = [
+        [token for token in segment["tokens"] if token < tokenizer.eot]
+        for segment in segments
+    ]
+
+    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    segment_lengths = [len(s["tokens"]) for s in segments]
-    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
-
-    for segment in segments:
-        segment["words"] = []
-
-    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
-    for i, timing in enumerate(alignment):
-        if timing.word:
-            segment = segments[token_sources[word_boundaries[i]]]
-            start = round(time_offset + timing.start, 2)
-            end = round(time_offset + timing.end, 2)
-            segment["words"].append(
-                dict(
-                    word=timing.word,
-                    start=start,
-                    end=end,
-                    probability=timing.probability,
+    word_index = 0
+
+    for segment, text_tokens in zip(segments, text_tokens_per_segment):
+        saved_tokens = 0
+        words = []
+
+        while word_index < len(alignment) and saved_tokens < len(text_tokens):
+            timing = alignment[word_index]
+
+            if timing.word:
+                words.append(
+                    dict(
+                        word=timing.word,
+                        start=round(time_offset + timing.start, 2),
+                        end=round(time_offset + timing.end, 2),
+                        probability=timing.probability,
+                    )
                 )
-            )
 
-    for segment in segments:
-        if len(words := segment["words"]) > 0:
+            saved_tokens += len(timing.tokens)
+            word_index += 1
+
+        if len(words) > 0:
             # adjust the segment-level timestamps based on the word-level timestamps
             segment["start"] = words[0]["start"]
             segment["end"] = words[-1]["end"]
+
+        segment["words"] = words
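
For context, the new bookkeeping can be illustrated in isolation. The sketch below is a simplified stand-in, not the actual whisper/timing.py code: WordTiming is reduced to two fields, the token ids are made up, and a hand-written alignment list replaces find_alignment(). Only the token-counting loop mirrors the one added in the diff: each WordTiming consumes len(timing.tokens) tokens, so every word is credited against the budget of exactly one segment.

# Standalone sketch of the per-segment word assignment added above.
# WordTiming is a simplified stand-in for Whisper's timing dataclass,
# and the alignment list is hand-written in place of find_alignment().
import itertools
from dataclasses import dataclass
from typing import List


@dataclass
class WordTiming:
    word: str
    tokens: List[int]  # token ids covered by this word


segments = [{"tokens": [1, 2, 3]}, {"tokens": [4, 5]}]
alignment = [
    WordTiming("Hello", tokens=[1, 2]),
    WordTiming(" world", tokens=[3]),
    WordTiming(" again", tokens=[4, 5]),
]

# Per-segment token lists; flattening them recovers the single list
# that the real code passes to find_alignment().
text_tokens_per_segment = [segment["tokens"] for segment in segments]
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
assert text_tokens == [1, 2, 3, 4, 5]

# Walk the alignment once, crediting each word's tokens against the
# current segment's budget, so every word lands in exactly one segment.
word_index = 0
for segment, tokens in zip(segments, text_tokens_per_segment):
    saved_tokens = 0
    words = []
    while word_index < len(alignment) and saved_tokens < len(tokens):
        timing = alignment[word_index]
        if timing.word:
            words.append(timing.word)
        saved_tokens += len(timing.tokens)
        word_index += 1
    segment["words"] = words

print([segment["words"] for segment in segments])
# [['Hello', ' world'], [' again']]

Note the apparent motivation visible in the diff itself: the old segment_lengths counted every entry of segment["tokens"], while text_tokens filtered out tokens >= tokenizer.eot, so the token_sources index could drift out of step with the alignment. The new per-segment lists keep the flattened alignment input and the per-segment budgets consistent by construction.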