
word-level timestamps in `transcribe()` (#869)

* word-level timestamps in `transcribe()`

* moving to `timing.py`

* numba implementation for dtw, replacing dtw-python

* triton implementation for dtw

* add test for dtw implementations

* triton implementation of median_filter

* a simple word-level timestamps test

* add scipy as dev dependency

* installs an older version of Triton if CUDA < 11.4

* fix broken merge

* loosen nvcc version match regex

* find_alignment() function

* miscellaneous improvements

* skip median filtering when the input is too small

* Expose punctuation options in cli and transcribe() (#973)

* fix merge error

* fix merge error 2

* annotating that word_timestamps is experimental

---------

Co-authored-by: ryanheise <ryan@ryanheise.com>
Jong Wook Kim 2 years ago
parent
commit
500d0fe966

+ 2 - 3
.github/workflows/test.yml

@@ -21,6 +21,5 @@ jobs:
      - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
      - uses: actions/checkout@v2
      - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
-      - run: pip install pytest
-      - run: pip install .
-      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
+      - run: pip install .["dev"]
+      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

+ 1 - 0
requirements.txt

@@ -1,3 +1,4 @@
+numba
 numpy
 torch
 tqdm

+ 18 - 2
setup.py

@@ -1,4 +1,5 @@
 import os
+import sys

 import pkg_resources
 from setuptools import setup, find_packages
@@ -9,6 +10,21 @@ def read_version(fname="whisper/version.py"):
     return locals()["__version__"]


+requirements = []
+if sys.platform.startswith("linux"):
+    triton_requirement = "triton>=2.0.0.dev20221202"
+    try:
+        import re
+        import subprocess
+        version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
+        major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
+        if (int(major), int(minor)) < (11, 4):
+            # the last version supporting CUDA < 11.4
+            triton_requirement = "triton==2.0.0.dev20221011"
+    except (IndexError, OSError, subprocess.SubprocessError):
+        pass
+    requirements.append(triton_requirement)
+
 setup(
     name="openai-whisper",
     py_modules=["whisper"],
@@ -22,7 +38,7 @@ setup(
     url="https://github.com/openai/whisper",
     license="MIT",
     packages=find_packages(exclude=["tests*"]),
-    install_requires=[
+    install_requires=requirements + [
         str(r)
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
@@ -32,5 +48,5 @@ setup(
         "console_scripts": ["whisper=whisper.transcribe:cli"],
     },
     include_package_data=True,
-    extras_require={"dev": ["pytest"]},
+    extras_require={"dev": ["pytest", "scipy"]},
 )

+ 14 - 0
tests/conftest.py

@@ -0,0 +1,14 @@
+import random as rand
+
+import numpy
+import pytest
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "requires_cuda")
+
+
+@pytest.fixture
+def random():
+    rand.seed(42)
+    numpy.random.seed(42)

+ 87 - 0
tests/test_timing.py

@@ -0,0 +1,87 @@
+import pytest
+import numpy as np
+import scipy.ndimage
+import torch
+
+from whisper.timing import dtw_cpu, dtw_cuda, median_filter
+
+
+sizes = [
+    (10, 20), (32, 16), (123, 1500), (234, 189),
+]
+shapes = [
+    (10,), (1, 15),  (4, 5, 345), (6, 12, 240, 512),
+]
+
+
+@pytest.mark.parametrize("N, M", sizes)
+def test_dtw(N: int, M: int):
+    steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
+    np.random.shuffle(steps)
+    x = np.random.random((N, M)).astype(np.float32)
+
+    i, j, k = 0, 0, 0
+    trace = []
+    while True:
+        x[i, j] -= 1
+        trace.append((i, j))
+
+        if k == len(steps):
+            break
+
+        if k + 1 < len(steps) and steps[k] != steps[k + 1]:
+            i += 1
+            j += 1
+            k += 2
+            continue
+
+        if steps[k] == 0:
+            i += 1
+        if steps[k] == 1:
+            j += 1
+        k += 1
+
+    trace = np.array(trace).T
+    dtw_trace = dtw_cpu(x)
+
+    assert np.allclose(trace, dtw_trace)
+
+
+@pytest.mark.requires_cuda
+@pytest.mark.parametrize("N, M", sizes)
+def test_dtw_cuda_equivalence(N: int, M: int):
+    x_numpy = np.random.randn(N, M).astype(np.float32)
+    x_cuda = torch.from_numpy(x_numpy).cuda()
+
+    trace_cpu = dtw_cpu(x_numpy)
+    trace_cuda = dtw_cuda(x_cuda)
+
+    assert np.allclose(trace_cpu, trace_cuda)
+
+
+@pytest.mark.parametrize("shape", shapes)
+def test_median_filter(shape):
+    x = torch.randn(*shape)
+
+    for filter_width in [3, 5, 7, 13]:
+        filtered = median_filter(x, filter_width)
+
+        # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
+        pad_width = filter_width // 2
+        padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
+        scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
+        scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
+
+        assert np.allclose(filtered, scipy_filtered)
+
+
+@pytest.mark.requires_cuda
+@pytest.mark.parametrize("shape", shapes)
+def test_median_filter_equivalence(shape):
+    x = torch.randn(*shape)
+
+    for filter_width in [3, 5, 7, 13]:
+        filtered_cpu = median_filter(x, filter_width)
+        filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
+
+        assert np.allclose(filtered_cpu, filtered_gpu)

+ 13 - 1
tests/test_transcribe.py

@@ -13,10 +13,22 @@ def test_transcribe(model_name: str):
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

     language = "en" if model_name.endswith(".en") else None
-    result = model.transcribe(audio_path, language=language, temperature=0.0)
+    result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
     assert result["language"] == "en"

     transcription = result["text"].lower()
     assert "my fellow americans" in transcription
     assert "your country" in transcription
     assert "do for you" in transcription
+
+    timing_checked = False
+    for segment in result["segments"]:
+        for timing in segment["words"]:
+            assert timing["start"] < timing["end"]
+            if timing["word"].strip(" ,") == "Americans":
+                assert timing["start"] <= 1.8
+                assert timing["end"] >= 1.8
+                print(timing)
+                timing_checked = True
+
+    assert timing_checked

+ 22 - 0
whisper/__init__.py

@@ -29,6 +29,23 @@ _MODELS = {
     "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
     "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
 }
 }
 
 
+# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
+# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
+_ALIGNMENT_HEADS = {
+    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
+    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
+    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
+    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
+    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
+    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
+    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
+    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
+    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
+    "large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+    "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+}
+
+
 
 
 def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)
@@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow

     if name in _MODELS:
         checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+        alignment_heads = _ALIGNMENT_HEADS[name]
     elif os.path.isfile(name):
         checkpoint_file = open(name, "rb").read() if in_memory else name
+        alignment_heads = None
     else:
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

@@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     model = Whisper(dims)
     model.load_state_dict(checkpoint["model_state_dict"])

+    if alignment_heads is not None:
+        model.set_alignment_heads(alignment_heads)
+
     return model.to(device)
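The `_ALIGNMENT_HEADS` dumps added above are the inverse of what `Whisper.set_alignment_heads()` decodes: a boolean (n_text_layer, n_text_head) mask, packed to raw bytes, gzip-compressed, then base85-encoded. A minimal round-trip sketch of that format follows; the `encode_alignment_heads` helper is hypothetical and shown only to illustrate the encoding, while the decode side mirrors `set_alignment_heads()` in whisper/model.py.

import base64
import gzip

import numpy as np


def encode_alignment_heads(mask: np.ndarray) -> bytes:
    # hypothetical helper: pack a boolean (n_text_layer, n_text_head) mask into
    # the base85-of-gzipped-bytes format used by the _ALIGNMENT_HEADS table above
    return base64.b85encode(gzip.compress(mask.astype(bool).tobytes()))


def decode_alignment_heads(dump: bytes, n_layers: int, n_heads: int) -> np.ndarray:
    # mirrors Whisper.set_alignment_heads(): base85 -> gzip -> flat bool array -> 2D mask
    array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool)
    return array.reshape(n_layers, n_heads)


# round-trip check on a toy 2-layer x 3-head mask
mask = np.array([[False, True, False], [True, False, True]])
assert (decode_alignment_heads(encode_alignment_heads(mask), 2, 3) == mask).all()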

+ 4 - 0
whisper/audio.py

@@ -18,6 +18,10 @@ CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
 N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input

+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 100 mel frames in 1s (10ms each)
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 audio tokens in 1s (20ms each)
+

 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """

+ 12 - 1
whisper/model.py

@@ -1,3 +1,5 @@
+import base64
+import gzip
 from dataclasses import dataclass
 from typing import Dict
 from typing import Iterable, Optional
@@ -8,8 +10,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch import nn

-from .transcribe import transcribe as transcribe_function
 from .decoding import detect_language as detect_language_function, decode as decode_function
+from .transcribe import transcribe as transcribe_function


 @dataclass
@@ -213,6 +215,15 @@ class Whisper(nn.Module):
             self.dims.n_text_head,
             self.dims.n_text_layer,
         )
+        # use the last half layers for alignment by default; see `set_alignment_heads()` below
+        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
+        all_heads[self.dims.n_text_layer // 2:] = True
+        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+    def set_alignment_heads(self, dump: bytes):
+        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
+        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
+        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
 
 
     def embed_audio(self, mel: torch.Tensor):
         return self.encoder(mel)

+ 305 - 0
whisper/timing.py

@@ -0,0 +1,305 @@
+import subprocess
+import warnings
+from dataclasses import dataclass
+from typing import List, TYPE_CHECKING
+
+import numba
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
+from .tokenizer import Tokenizer
+
+if TYPE_CHECKING:
+    from .model import Whisper
+
+
+def median_filter(x: torch.Tensor, filter_width: int):
+    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
+    pad_width = filter_width // 2
+    if x.shape[-1] <= pad_width:
+        # F.pad requires the padding width to be smaller than the input dimension
+        return x
+
+    if (ndim := x.ndim) <= 2:
+        # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
+        x = x[None, None, :]
+
+    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
+
+    result = None
+    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
+    if x.is_cuda:
+        try:
+            from .triton_ops import median_filter_cuda
+            result = median_filter_cuda(x, filter_width)
+        except (RuntimeError, subprocess.CalledProcessError):
+            warnings.warn(
+                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
+                "falling back to a slower median kernel implementation..."
+            )
+
+    if result is None:
+        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
+        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
+
+    if ndim <= 2:
+        result = result[0, 0]
+
+    return result
+
+@numba.jit
+def backtrace(trace: np.ndarray):
+    i = trace.shape[0] - 1
+    j = trace.shape[1] - 1
+    trace[0, :] = 2
+    trace[:, 0] = 1
+
+    result = []
+    while i > 0 or j > 0:
+        result.append((i - 1, j - 1))
+
+        if trace[i, j] == 0:
+            i -= 1
+            j -= 1
+        elif trace[i, j] == 1:
+            i -= 1
+        elif trace[i, j] == 2:
+            j -= 1
+        else:
+            raise ValueError("Unexpected trace[i, j]")
+
+    result = np.array(result)
+    return result[::-1, :].T
+
+
+@numba.jit(nopython=True, parallel=True)
+def dtw_cpu(x: np.ndarray):
+    N, M = x.shape
+    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
+    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
+
+    cost[0, 0] = 0
+    for j in range(1, M + 1):
+        for i in range(1, N + 1):
+            c0 = cost[i - 1, j - 1]
+            c1 = cost[i - 1, j]
+            c2 = cost[i, j - 1]
+
+            if c0 < c1 and c0 < c2:
+                c, t = c0, 0
+            elif c1 < c0 and c1 < c2:
+                c, t = c1, 1
+            else:
+                c, t = c2, 2
+
+            cost[i, j] = x[i - 1, j - 1] + c
+            trace[i, j] = t
+
+    return backtrace(trace)
+
+
+def dtw_cuda(x, BLOCK_SIZE=1024):
+    from .triton_ops import dtw_kernel
+
+    M, N = x.shape
+    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
+
+    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    x_skew = x_skew.T.contiguous()
+    cost = torch.ones(N + M + 2, M + 2) * np.inf
+    cost[0, 0] = 0
+    cost = cost.cuda()
+    trace = torch.zeros_like(cost, dtype=torch.int32)
+
+    dtw_kernel[(1,)](
+        cost,
+        trace,
+        x_skew,
+        x_skew.stride(0),
+        cost.stride(0),
+        trace.stride(0),
+        N,
+        M,
+        BLOCK_SIZE=BLOCK_SIZE
+    )
+
+    trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1]
+    return backtrace(trace.cpu().numpy())
+
+
+def dtw(x: torch.Tensor) -> np.ndarray:
+    if x.is_cuda:
+        try:
+            return dtw_cuda(x)
+        except (RuntimeError, subprocess.CalledProcessError):
+            warnings.warn(
+                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
+                "falling back to a slower DTW implementation..."
+            )
+
+    return dtw_cpu(x.double().cpu().numpy())
+
+
+@dataclass
+class WordTiming:
+    word: str
+    tokens: List[int]
+    start: float
+    end: float
+    probability: float
+
+
+def find_alignment(
+    model: "Whisper",
+    tokenizer: Tokenizer,
+    text_tokens: List[int],
+    mel: torch.Tensor,
+    num_frames: int,
+    *,
+    medfilt_width: int = 7,
+    qk_scale: float = 1.0,
+) -> List[WordTiming]:
+    tokens = torch.tensor(
+        [
+            *tokenizer.sot_sequence,
+            tokenizer.no_timestamps,
+            *text_tokens,
+            tokenizer.eot,
+        ]
+    ).to(model.device)
+
+    # install hooks on the cross attention layers to retrieve the attention weights
+    QKs = [None] * model.dims.n_text_layer
+    hooks = [
+        block.cross_attn.register_forward_hook(
+            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
+        )
+        for i, block in enumerate(model.decoder.blocks)
+    ]
+
+    with torch.no_grad():
+        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
+        token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
+        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
+
+    for hook in hooks:
+        hook.remove()
+
+    # heads * tokens * frames
+    weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
+    weights = weights[:, :, : num_frames // 2]
+    weights = (weights * qk_scale).softmax(dim=-1)
+    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
+    weights = (weights - mean) / std
+    weights = median_filter(weights, medfilt_width)
+
+    matrix = weights.mean(axis=0)
+    matrix = matrix[len(tokenizer.sot_sequence):-1]
+    text_indices, time_indices = dtw(-matrix)
+
+    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
+    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
+
+    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
+    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
+    start_times = jump_times[word_boundaries[:-1]]
+    end_times = jump_times[word_boundaries[1:]]
+    word_probabilities = [
+        np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+    ]
+
+    # hack: ensure the first and second words are not longer than twice the median word duration.
+    # a better segmentation algorithm based on VAD should be able to replace this.
+    word_durations = end_times - start_times
+    word_durations = word_durations[word_durations.nonzero()]
+    if len(word_durations) > 0:
+        median_duration = np.median(word_durations)
+        max_duration = median_duration * 2
+        if len(word_durations) >= 2 and word_durations[1] > max_duration:
+            end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
+        if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
+            start_times[0] = max(0, end_times[0] - max_duration)
+
+    return [
+        WordTiming(word, tokens, start, end, probability)
+        for word, tokens, start, end, probability in zip(
+            words, word_tokens, start_times, end_times, word_probabilities
+        )
+    ]
+
+
+def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
+    # merge prepended punctuations
+    i = len(alignment) - 2
+    j = len(alignment) - 1
+    while i >= 0:
+        previous = alignment[i]
+        following = alignment[j]
+        if previous.word.startswith(" ") and previous.word.strip() in prepended:
+            # prepend it to the following word
+            following.word = previous.word + following.word
+            following.tokens = previous.tokens + following.tokens
+            previous.word = ""
+            previous.tokens = []
+        else:
+            j = i
+        i -= 1
+
+    # merge appended punctuations
+    i = 0
+    j = 1
+    while j < len(alignment):
+        previous = alignment[i]
+        following = alignment[j]
+        if not previous.word.endswith(" ") and following.word in appended:
+            # append it to the previous word
+            previous.word = previous.word + following.word
+            previous.tokens = previous.tokens + following.tokens
+            following.word = ""
+            following.tokens = []
+        else:
+            i = j
+        j += 1
+
+
+def add_word_timestamps(
+    *,
+    segments: List[dict],
+    model: "Whisper",
+    tokenizer: Tokenizer,
+    mel: torch.Tensor,
+    num_frames: int,
+    prepend_punctuations: str = "\"\'“¿([{-",
+    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
+    **hyperparams,
+):
+    if len(segments) == 0:
+        return
+
+    text_tokens = [t for segment in segments for t in segment["tokens"]]
+    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
+    merge_punctuations(alignment, prepend_punctuations, append_punctuations)
+
+    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
+    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
+
+    for segment in segments:
+        segment["words"] = []
+
+    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
+    for i, timing in enumerate(alignment):
+        if timing.word:
+            segment = segments[token_sources[word_boundaries[i]]]
+            start = round(time_offset + timing.start, 2)
+            end = round(time_offset + timing.end, 2)
+            segment["words"].append(
+                dict(word=timing.word, start=start, end=end, probability=timing.probability)
+            )
+
+    for segment in segments:
+        if len(words := segment["words"]) > 0:
+            # adjust the segment-level timestamps based on the word-level timestamps
+            segment["start"] = words[0]["start"]
+            segment["end"] = words[-1]["end"]

+ 43 - 0
whisper/tokenizer.py

@@ -1,4 +1,5 @@
 import os
+import string
 from dataclasses import dataclass
 from functools import lru_cache, cached_property
 from typing import List, Optional, Tuple, Union
@@ -265,6 +266,48 @@ class Tokenizer:
         assert len(tokens) == 1, f"{text} is not encoded as a single token"
         return tokens[0]

+    def split_to_word_tokens(self, tokens: List[int]):
+        if self.language in {"zh", "ja", "th", "lo", "my"}:
+            # These languages don't typically use spaces, so it is difficult to split words
+            # without morpheme analysis. Here, we instead split words at any
+            # position where the tokens are decoded as valid unicode points
+            return self.split_tokens_on_unicode(tokens)
+
+        return self.split_tokens_on_spaces(tokens)
+
+    def split_tokens_on_unicode(self, tokens: List[int]):
+        words = []
+        word_tokens = []
+        current_tokens = []
+
+        for token in tokens:
+            current_tokens.append(token)
+            decoded = self.decode_with_timestamps(current_tokens)
+            if "\ufffd" not in decoded:
+                words.append(decoded)
+                word_tokens.append(current_tokens)
+                current_tokens = []
+
+        return words, word_tokens
+
+    def split_tokens_on_spaces(self, tokens: List[int]):
+        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
+        words = []
+        word_tokens = []
+
+        for subword, subword_tokens in zip(subwords, subword_tokens_list):
+            special = subword_tokens[0] >= self.eot
+            with_space = subword.startswith(" ")
+            punctuation = subword.strip() in string.punctuation
+            if special or with_space or punctuation or len(words) == 0:
+                words.append(subword)
+                word_tokens.append(subword_tokens)
+            else:
+                words[-1] = words[-1] + subword
+                word_tokens[-1].extend(subword_tokens)
+
+        return words, word_tokens
+
 
 
 @lru_cache(maxsize=None)
 def build_tokenizer(name: str = "gpt2"):

+ 108 - 52
whisper/transcribe.py

@@ -7,8 +7,9 @@ import numpy as np
 import torch
 import tqdm

-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
+from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult
+from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer

@@ -27,6 +28,9 @@ def transcribe(
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
+    word_timestamps: bool = False,
+    prepend_punctuations: str = "\"\'“¿([{-",
+    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
     **decode_options,
 ):
     """
@@ -63,6 +67,21 @@ def transcribe(
         disabling may make the text inconsistent across windows, but the model becomes less prone to
         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

+    word_timestamps: bool
+        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
+        and include the timestamps for each word in each segment.
+
+    prepend_punctuations: str
+        If word_timestamps is True, merge these punctuation symbols with the next word
+
+    append_punctuations: str
+        If word_timestamps is True, merge these punctuation symbols with the previous word
+
+    initial_prompt: Optional[str]
+        Optional text to provide as a prompt for the first window. This can be used to provide, or
+        "prompt-engineer", a context for transcription, e.g. custom vocabularies or proper nouns
+        to make it more likely to predict those words correctly.
+
     decode_options: dict
         Keyword arguments to construct `DecodingOptions` instances

@@ -90,16 +109,19 @@ def transcribe(
         else:
             if verbose:
                 print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
-            segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-            _, probs = model.detect_language(segment)
+            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+            _, probs = model.detect_language(mel_segment)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
                 print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

-    language = decode_options["language"]
-    task = decode_options.get("task", "transcribe")
+    language: str = decode_options["language"]
+    task: str = decode_options.get("task", "transcribe")
     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

+    if word_timestamps and task == "translate":
+        warnings.warn("Word-level timestamps on translations may not be reliable.")
+
     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
         temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
         decode_result = None
@@ -145,42 +167,35 @@ def transcribe(
     else:
         initial_prompt_tokens = []

-    def add_segment(
-        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
+    def new_segment(
+        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
     ):
-        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
-        if len(text.strip()) == 0:  # skip empty text output
-            return
-
-        all_segments.append(
-            {
-                "id": len(all_segments),
-                "seek": seek,
-                "start": start,
-                "end": end,
-                "text": text,
-                "tokens": text_tokens.tolist(),
-                "temperature": result.temperature,
-                "avg_logprob": result.avg_logprob,
-                "compression_ratio": result.compression_ratio,
-                "no_speech_prob": result.no_speech_prob,
-            }
-        )
-        if verbose:
-            print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
-
-    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
+        text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
+        return {
+            "id": len(all_segments),
+            "seek": seek,
+            "start": start,
+            "end": end,
+            "text": tokenizer.decode(text_tokens),
+            "tokens": text_tokens,
+            "temperature": result.temperature,
+            "avg_logprob": result.avg_logprob,
+            "compression_ratio": result.compression_ratio,
+            "no_speech_prob": result.no_speech_prob,
+        }
+
+    # show the progress bar when verbose is False (if True, transcribed text will be printed)
     num_frames = mel.shape[-1]
     num_frames = mel.shape[-1]
-    previous_seek_value = seek
-
     with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
-            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
-            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            mel_segment = mel[:, seek:]
+            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

             decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result: DecodingResult = decode_with_fallback(segment)
+            result: DecodingResult = decode_with_fallback(mel_segment)
             tokens = torch.tensor(result.tokens)

             if no_speech_threshold is not None:
@@ -191,29 +206,36 @@ def transcribe(
                     should_skip = False

                 if should_skip:
-                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
+                    seek += segment_size  # fast-forward to the next segment boundary
                     continue

+            current_segments = []
+            current_tokens = []
+
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
             if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
                 if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
                     consecutive = consecutive.tolist() + [len(tokens)]
+
                 last_slice = 0
                 for current_slice in consecutive:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
                     end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                    add_segment(
-                        start=timestamp_offset + start_timestamp_pos * time_precision,
-                        end=timestamp_offset + end_timestamp_pos * time_precision,
-                        text_tokens=sliced_tokens[1:-1],
+                    current_segments.append(new_segment(
+                        start=time_offset + start_timestamp_pos * time_precision,
+                        end=time_offset + end_timestamp_pos * time_precision,
+                        tokens=sliced_tokens,
                         result=result,
-                    )
+                    ))
+                    current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
+
                 if ended_with_single_timestamp:
                     # single timestamp at the end means no speech after the last timestamp.
-                    seek += segment.shape[-1]
+                    seek += segment_size
                 else:
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
                     last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
@@ -227,23 +249,54 @@ def transcribe(
                     last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
                     duration = last_timestamp_pos * time_precision

-                    start=timestamp_offset,
-                    end=timestamp_offset + duration,
-                    text_tokens=tokens,
+                current_segments.append(new_segment(
+                    start=time_offset,
+                    end=time_offset + duration,
+                    tokens=tokens,
                     result=result,
                     result=result,
-                )
-
-                seek += segment.shape[-1]
-                all_tokens.extend(tokens.tolist())
+                ))
+                current_tokens.append(tokens.tolist())
+                seek += segment_size

             if not condition_on_previous_text or result.temperature > 0.5:
                 # do not feed the prompt tokens if a high temperature was used
                 prompt_reset_since = len(all_tokens)

+                add_word_timestamps(
+                    segments=current_segments,
+                    model=model,
+                    tokenizer=tokenizer,
+                    mel=mel_segment,
+                    num_frames=segment_size,
+                    prepend_punctuations=prepend_punctuations,
+                    append_punctuations=append_punctuations,
+                )
+                word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
+                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                    seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND)
+                    if seek_shift > 0:
+                        seek = previous_seek + seek_shift
+
+            if verbose:
+                for segment in current_segments:
+                    start, end, text = segment["start"], segment["end"], segment["text"]
+                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
+                    print(make_safe(line))
+
+            # if a segment is instantaneous or does not contain text, clear it
+            for i, segment in enumerate(current_segments):
+                if segment["start"] == segment["end"] or segment["text"].strip() == "":
+                    segment["text"] = ""
+                    segment["tokens"] = []
+                    segment["words"] = []
+                    current_tokens[i] = []
+
+            all_segments.extend(current_segments)
+            all_tokens.extend([token for segment in current_tokens for token in segment])
+
             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek_value)
-            previous_seek_value = seek
+            pbar.update(min(num_frames, seek) - previous_seek)

     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
@@ -282,6 +335,9 @@ def cli():
     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+    parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
+    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
+    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
 
 
     args = parser.parse_args().__dict__
     args = parser.parse_args().__dict__
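Taken together, the new options surface like this. A minimal usage sketch follows; "audio.wav" is a placeholder path, and the CLI equivalent would be `whisper audio.wav --model tiny --word_timestamps True`.

import whisper

model = whisper.load_model("tiny")  # alignment heads are set automatically for named models
result = model.transcribe("audio.wav", word_timestamps=True)

for segment in result["segments"]:
    for word in segment["words"]:
        # each entry carries the word text, start/end in seconds, and a probability
        print(f'{word["start"]:6.2f} -> {word["end"]:6.2f}  {word["word"]}')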

+ 92 - 0
whisper/triton_ops.py

@@ -0,0 +1,92 @@
+import math
+
+import numpy as np
+import torch
+from functools import lru_cache
+
+try:
+    import triton
+    import triton.language as tl
+except ImportError:
+    raise RuntimeError("triton import failed; try `pip install --pre triton`")
+
+
+@triton.jit
+def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < M
+
+    for k in range(1, N + M + 1):  # k = i + j
+        tl.debug_barrier()
+
+        p0 = cost + (k - 1) * cost_stride
+        p1 = cost + k * cost_stride
+        p2 = cost + k * cost_stride + 1
+
+        c0 = tl.load(p0 + offsets, mask=mask)
+        c1 = tl.load(p1 + offsets, mask=mask)
+        c2 = tl.load(p2 + offsets, mask=mask)
+
+        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
+        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
+
+        cost_ptr = cost + (k + 1) * cost_stride + 1
+        tl.store(cost_ptr + offsets, cost_row, mask=mask)
+
+        trace_ptr = trace + (k + 1) * trace_stride + 1
+        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
+        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
+        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
+
+
+@lru_cache(maxsize=None)
+def median_kernel(filter_width: int):
+    @triton.jit
+    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == filter_width
+        row_idx = tl.program_id(0)
+        offsets = tl.arange(0, BLOCK_SIZE)
+        mask = offsets < y_stride
+
+        x_ptr = x + row_idx * x_stride
+        y_ptr = y + row_idx * y_stride
+
+        LOAD_ALL_ROWS_HERE
+
+        BUBBLESORT_HERE
+
+        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)
+
+    kernel = triton.JITFunction(kernel.fn)
+    kernel.src = kernel.src.replace("    LOAD_ALL_ROWS_HERE", "\n".join([
+        f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
+        for i in range(filter_width)
+    ]))
+    kernel.src = kernel.src.replace("    BUBBLESORT_HERE", "\n\n".join([
+        "\n\n".join([
+            "\n".join([
+                f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
+                f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
+                f"    row{j} = smaller",
+                f"    row{j + 1} = larger",
+            ])
+            for j in range(filter_width - i - 1)
+        ])
+        for i in range(filter_width // 2 + 1)
+    ]))
+    kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
+
+    return kernel
+
+
+def median_filter_cuda(x: torch.Tensor, filter_width: int):
+    """Apply a median filter of given width along the last dimension of x"""
+    slices = x.contiguous().unfold(-1, filter_width, 1)
+    grid = np.prod(slices.shape[:-2])
+
+    kernel = median_kernel(filter_width)
+    y = torch.empty_like(slices[..., 0])
+
+    BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
+    kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
+
+    return y
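Since `median_kernel()` builds its Triton source by string substitution, the generated kernel never appears in the repository. For illustration, the expansion for `filter_width=3` would look roughly like the following; this is assembled by hand from the replacements above, not taken from generated output, and the kernel name is made up.

import triton
import triton.language as tl


@triton.jit
def median_kernel_3(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == 3
    row_idx = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < y_stride

    x_ptr = x + row_idx * x_stride
    y_ptr = y + row_idx * y_stride

    # LOAD_ALL_ROWS_HERE expands to one load per window position
    row0 = tl.load(x_ptr + offsets + 0, mask=mask)
    row1 = tl.load(x_ptr + offsets + 1, mask=mask)
    row2 = tl.load(x_ptr + offsets + 2, mask=mask)

    # BUBBLESORT_HERE expands to enough compare-swap passes that the middle
    # register (row1) ends up holding the median of the three values
    smaller = tl.where(row0 < row1, row0, row1)
    larger = tl.where(row0 > row1, row0, row1)
    row0 = smaller
    row1 = larger

    smaller = tl.where(row1 < row2, row1, row2)
    larger = tl.where(row1 > row2, row1, row2)
    row1 = smaller
    row2 = larger

    smaller = tl.where(row0 < row1, row0, row1)
    larger = tl.where(row0 > row1, row0, row1)
    row0 = smaller
    row1 = larger

    # MIDDLE_ROW_HERE is row{filter_width // 2}
    tl.store(y_ptr + offsets, row1, mask=mask)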

+ 48 - 19
whisper/utils.py

@@ -85,34 +85,63 @@ class WriteTXT(ResultWriter):
             print(segment['text'].strip(), file=file, flush=True)


-class WriteVTT(ResultWriter):
+class SubtitlesWriter(ResultWriter):
+    always_include_hours: bool
+    decimal_marker: str
+
+    def iterate_result(self, result: dict):
+        for segment in result["segments"]:
+            segment_start = self.format_timestamp(segment["start"])
+            segment_end = self.format_timestamp(segment["end"])
+            segment_text = segment['text'].strip().replace('-->', '->')
+
+            if word_timings := segment.get("words", None):
+                all_words = [timing["word"] for timing in word_timings]
+                all_words[0] = all_words[0].strip()  # remove the leading space, if any
+                last = segment_start
+                for i, this_word in enumerate(word_timings):
+                    start = self.format_timestamp(this_word["start"])
+                    end = self.format_timestamp(this_word["end"])
+                    if last != start:
+                        yield last, start, segment_text
+
+                    yield start, end, "".join(
+                        [f"<u>{word}</u>" if j == i else word for j, word in enumerate(all_words)]
+                    )
+                    last = end
+
+                if last != segment_end:
+                    yield last, segment_end, segment_text
+            else:
+                yield segment_start, segment_end, segment_text
+
+    def format_timestamp(self, seconds: float):
+        return format_timestamp(
+            seconds=seconds,
+            always_include_hours=self.always_include_hours,
+            decimal_marker=self.decimal_marker,
+        )
+
+
+class WriteVTT(SubtitlesWriter):
     extension: str = "vtt"
     extension: str = "vtt"
+    always_include_hours: bool = False
+    decimal_marker: str = '.'
 
 
     def write_result(self, result: dict, file: TextIO):
     def write_result(self, result: dict, file: TextIO):
         print("WEBVTT\n", file=file)
         print("WEBVTT\n", file=file)
-        for segment in result["segments"]:
-            print(
-                f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
-                f"{segment['text'].strip().replace('-->', '->')}\n",
-                file=file,
-                flush=True,
-            )
+        for start, end, text in self.iterate_result(result):
+            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
 
 
 
 
-class WriteSRT(ResultWriter):
+class WriteSRT(SubtitlesWriter):
     extension: str = "srt"
     extension: str = "srt"
+    always_include_hours: bool = True
+    decimal_marker: str = ','
 
 
     def write_result(self, result: dict, file: TextIO):
     def write_result(self, result: dict, file: TextIO):
-        for i, segment in enumerate(result["segments"], start=1):
-            # write srt lines
-            print(
-                f"{i}\n"
-                f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
-                f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
-                f"{segment['text'].strip().replace('-->', '->')}\n",
-                file=file,
-                flush=True,
-            )
+        for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
+            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
 
 
 
 
 class WriteTSV(ResultWriter):