@@ -170,6 +170,9 @@ def find_alignment(
medfilt_width: int = 7,
qk_scale: float = 1.0,
) -> List[WordTiming]:
+ if len(text_tokens) == 0:
+ return []
+
tokens = torch.tensor(
[
*tokenizer.sot_sequence,