Преглед на файлове

word-level timestamps in `transcribe()` (#869)

* word-level timestamps in `transcribe()`

* moving to `timing.py`

* numba implementation for dtw, replacing dtw-python

* triton implementation for dtw

* add test for dtw implementations

* triton implementation of median_filter

* a simple word-level timestamps test

* add scipy as dev dependency

* installs an older version of Triton if CUDA < 11.4

* fix broken merge

* loosen nvcc version match regex

* find_alignment() function

* miscellaneous improvements

* skip median filtering when the input is too small

* Expose punctuation options in cli and transcribe() (#973)

* fix merge error

* fix merge error 2

* annotating that word_timestamps is experimental

---------

Co-authored-by: ryanheise <ryan@ryanheise.com>
Jong Wook Kim преди 1 година
родител
ревизия
500d0fe966
променени са 14 файла, в които са добавени 769 реда и са изтрити 78 реда
  1. 2 3
      .github/workflows/test.yml
  2. 1 0
      requirements.txt
  3. 18 2
      setup.py
  4. 14 0
      tests/conftest.py
  5. 87 0
      tests/test_timing.py
  6. 13 1
      tests/test_transcribe.py
  7. 22 0
      whisper/__init__.py
  8. 4 0
      whisper/audio.py
  9. 12 1
      whisper/model.py
  10. 305 0
      whisper/timing.py
  11. 43 0
      whisper/tokenizer.py
  12. 108 52
      whisper/transcribe.py
  13. 92 0
      whisper/triton_ops.py
  14. 48 19
      whisper/utils.py

+ 2 - 3
.github/workflows/test.yml

@@ -21,6 +21,5 @@ jobs:
       - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
       - uses: actions/checkout@v2
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
-      - run: pip install pytest
-      - run: pip install .
-      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
+      - run: pip install .["dev"]
+      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

+ 1 - 0
requirements.txt

@@ -1,3 +1,4 @@
+numba
 numpy
 torch
 tqdm

+ 18 - 2
setup.py

@@ -1,4 +1,5 @@
 import os
+import sys
 
 import pkg_resources
 from setuptools import setup, find_packages
@@ -9,6 +10,21 @@ def read_version(fname="whisper/version.py"):
     return locals()["__version__"]
 
 
+requirements = []
+if sys.platform.startswith("linux"):
+    triton_requirement = "triton>=2.0.0.dev20221202"
+    try:
+        import re
+        import subprocess
+        version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
+        major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
+        if (int(major), int(minor)) < (11, 4):
+            # the last version supporting CUDA < 11.4
+            triton_requirement = "triton==2.0.0.dev20221011"
+    except (IndexError, OSError, subprocess.SubprocessError):
+        pass
+    requirements.append(triton_requirement)
+
 setup(
     name="openai-whisper",
     py_modules=["whisper"],
@@ -22,7 +38,7 @@ setup(
     url="https://github.com/openai/whisper",
     license="MIT",
     packages=find_packages(exclude=["tests*"]),
-    install_requires=[
+    install_requires=requirements + [
         str(r)
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
@@ -32,5 +48,5 @@ setup(
         "console_scripts": ["whisper=whisper.transcribe:cli"],
     },
     include_package_data=True,
-    extras_require={"dev": ["pytest"]},
+    extras_require={"dev": ["pytest", "scipy"]},
 )

+ 14 - 0
tests/conftest.py

@@ -0,0 +1,14 @@
+import random as rand
+
+import numpy
+import pytest
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "requires_cuda")
+
+
+@pytest.fixture
+def random():
+    rand.seed(42)
+    numpy.random.seed(42)

+ 87 - 0
tests/test_timing.py

@@ -0,0 +1,87 @@
+import pytest
+import numpy as np
+import scipy.ndimage
+import torch
+
+from whisper.timing import dtw_cpu, dtw_cuda, median_filter
+
+
+sizes = [
+    (10, 20), (32, 16), (123, 1500), (234, 189),
+]
+shapes = [
+    (10,), (1, 15),  (4, 5, 345), (6, 12, 240, 512),
+]
+
+
+@pytest.mark.parametrize("N, M", sizes)
+def test_dtw(N: int, M: int):
+    steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
+    np.random.shuffle(steps)
+    x = np.random.random((N, M)).astype(np.float32)
+
+    i, j, k = 0, 0, 0
+    trace = []
+    while True:
+        x[i, j] -= 1
+        trace.append((i, j))
+
+        if k == len(steps):
+            break
+
+        if k + 1 < len(steps) and steps[k] != steps[k + 1]:
+            i += 1
+            j += 1
+            k += 2
+            continue
+
+        if steps[k] == 0:
+            i += 1
+        if steps[k] == 1:
+            j += 1
+        k += 1
+
+    trace = np.array(trace).T
+    dtw_trace = dtw_cpu(x)
+
+    assert np.allclose(trace, dtw_trace)
+
+
+@pytest.mark.requires_cuda
+@pytest.mark.parametrize("N, M", sizes)
+def test_dtw_cuda_equivalence(N: int, M: int):
+    x_numpy = np.random.randn(N, M).astype(np.float32)
+    x_cuda = torch.from_numpy(x_numpy).cuda()
+
+    trace_cpu = dtw_cpu(x_numpy)
+    trace_cuda = dtw_cuda(x_cuda)
+
+    assert np.allclose(trace_cpu, trace_cuda)
+
+
+@pytest.mark.parametrize("shape", shapes)
+def test_median_filter(shape):
+    x = torch.randn(*shape)
+
+    for filter_width in [3, 5, 7, 13]:
+        filtered = median_filter(x, filter_width)
+
+        # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
+        pad_width = filter_width // 2
+        padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
+        scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
+        scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
+
+        assert np.allclose(filtered, scipy_filtered)
+
+
+@pytest.mark.requires_cuda
+@pytest.mark.parametrize("shape", shapes)
+def test_median_filter_equivalence(shape):
+    x = torch.randn(*shape)
+
+    for filter_width in [3, 5, 7, 13]:
+        filtered_cpu = median_filter(x, filter_width)
+        filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
+
+        assert np.allclose(filtered_cpu, filtered_gpu)

+ 13 - 1
tests/test_transcribe.py

@@ -13,10 +13,22 @@ def test_transcribe(model_name: str):
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
 
     language = "en" if model_name.endswith(".en") else None
-    result = model.transcribe(audio_path, language=language, temperature=0.0)
+    result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
     assert result["language"] == "en"
 
     transcription = result["text"].lower()
     assert "my fellow americans" in transcription
     assert "your country" in transcription
     assert "do for you" in transcription
+
+    timing_checked = False
+    for segment in result["segments"]:
+        for timing in segment["words"]:
+            assert timing["start"] < timing["end"]
+            if timing["word"].strip(" ,") == "Americans":
+                assert timing["start"] <= 1.8
+                assert timing["end"] >= 1.8
+                print(timing)
+                timing_checked = True
+
+    assert timing_checked

+ 22 - 0
whisper/__init__.py

@@ -29,6 +29,23 @@ _MODELS = {
     "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
 }
 
+# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
+# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
+_ALIGNMENT_HEADS = {
+    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
+    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
+    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
+    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
+    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
+    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
+    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
+    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
+    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
+    "large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+    "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+}
+
+
 
 def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)
@@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
 
     if name in _MODELS:
         checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+        alignment_heads = _ALIGNMENT_HEADS[name]
     elif os.path.isfile(name):
         checkpoint_file = open(name, "rb").read() if in_memory else name
+        alignment_heads = None
     else:
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
 
@@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     model = Whisper(dims)
     model.load_state_dict(checkpoint["model_state_dict"])
 
+    if alignment_heads is not None:
+        model.set_alignment_heads(alignment_heads)
+
     return model.to(device)

+ 4 - 0
whisper/audio.py

@@ -18,6 +18,10 @@ CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
 N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input
 
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 100 mel frames in 1s (10ms each)
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 audio tokens in 1s (20ms each)
+
 
 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """

+ 12 - 1
whisper/model.py

@@ -1,3 +1,5 @@
+import base64
+import gzip
 from dataclasses import dataclass
 from typing import Dict
 from typing import Iterable, Optional
@@ -8,8 +10,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch import nn
 
-from .transcribe import transcribe as transcribe_function
 from .decoding import detect_language as detect_language_function, decode as decode_function
+from .transcribe import transcribe as transcribe_function
 
 
 @dataclass
@@ -213,6 +215,15 @@ class Whisper(nn.Module):
             self.dims.n_text_head,
             self.dims.n_text_layer,
         )
+        # use the last half layers for alignment by default; see `set_alignment_heads()` below
+        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
+        all_heads[self.dims.n_text_layer // 2:] = True
+        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+    def set_alignment_heads(self, dump: bytes):
+        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
+        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
+        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
 
     def embed_audio(self, mel: torch.Tensor):
         return self.encoder(mel)

+ 305 - 0
whisper/timing.py

@@ -0,0 +1,305 @@
+import subprocess
+import warnings
+from dataclasses import dataclass
+from typing import List, TYPE_CHECKING
+
+import numba
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
+from .tokenizer import Tokenizer
+
+if TYPE_CHECKING:
+    from .model import Whisper
+
+
+def median_filter(x: torch.Tensor, filter_width: int):
+    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
+    pad_width = filter_width // 2
+    if x.shape[-1] <= pad_width:
+        # F.pad requires the padding width to be smaller than the input dimension
+        return x
+
+    if (ndim := x.ndim) <= 2:
+        # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
+        x = x[None, None, :]
+
+    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
+
+    result = None
+    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
+    if x.is_cuda:
+        try:
+            from .triton_ops import median_filter_cuda
+            result = median_filter_cuda(x, filter_width)
+        except (RuntimeError, subprocess.CalledProcessError):
+            warnings.warn(
+                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
+                "falling back to a slower median kernel implementation..."
+            )
+
+    if result is None:
+        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
+        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
+
+    if ndim <= 2:
+        result = result[0, 0]
+
+    return result
+
+@numba.jit
+def backtrace(trace: np.ndarray):
+    i = trace.shape[0] - 1
+    j = trace.shape[1] - 1
+    trace[0, :] = 2
+    trace[:, 0] = 1
+
+    result = []
+    while i > 0 or j > 0:
+        result.append((i - 1, j - 1))
+
+        if trace[i, j] == 0:
+            i -= 1
+            j -= 1
+        elif trace[i, j] == 1:
+            i -= 1
+        elif trace[i, j] == 2:
+            j -= 1
+        else:
+            raise ValueError("Unexpected trace[i, j]")
+
+    result = np.array(result)
+    return result[::-1, :].T
+
+
+@numba.jit(nopython=True, parallel=True)
+def dtw_cpu(x: np.ndarray):
+    N, M = x.shape
+    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
+    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
+
+    cost[0, 0] = 0
+    for j in range(1, M + 1):
+        for i in range(1, N + 1):
+            c0 = cost[i - 1, j - 1]
+            c1 = cost[i - 1, j]
+            c2 = cost[i, j - 1]
+
+            if c0 < c1 and c0 < c2:
+                c, t = c0, 0
+            elif c1 < c0 and c1 < c2:
+                c, t = c1, 1
+            else:
+                c, t = c2, 2
+
+            cost[i, j] = x[i - 1, j - 1] + c
+            trace[i, j] = t
+
+    return backtrace(trace)
+
+
+def dtw_cuda(x, BLOCK_SIZE=1024):
+    from .triton_ops import dtw_kernel
+
+    M, N = x.shape
+    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
+
+    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    x_skew = x_skew.T.contiguous()
+    cost = torch.ones(N + M + 2, M + 2) * np.inf
+    cost[0, 0] = 0
+    cost = cost.cuda()
+    trace = torch.zeros_like(cost, dtype=torch.int32)
+
+    dtw_kernel[(1,)](
+        cost,
+        trace,
+        x_skew,
+        x_skew.stride(0),
+        cost.stride(0),
+        trace.stride(0),
+        N,
+        M,
+        BLOCK_SIZE=BLOCK_SIZE
+    )
+
+    trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1]
+    return backtrace(trace.cpu().numpy())
+
+
+def dtw(x: torch.Tensor) -> np.ndarray:
+    if x.is_cuda:
+        try:
+            return dtw_cuda(x)
+        except (RuntimeError, subprocess.CalledProcessError):
+            warnings.warn(
+                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
+                "falling back to a slower DTW implementation..."
+            )
+
+    return dtw_cpu(x.double().cpu().numpy())
+
+
+@dataclass
+class WordTiming:
+    word: str
+    tokens: List[int]
+    start: float
+    end: float
+    probability: float
+
+
+def find_alignment(
+    model: "Whisper",
+    tokenizer: Tokenizer,
+    text_tokens: List[int],
+    mel: torch.Tensor,
+    num_frames: int,
+    *,
+    medfilt_width: int = 7,
+    qk_scale: float = 1.0,
+) -> List[WordTiming]:
+    tokens = torch.tensor(
+        [
+            *tokenizer.sot_sequence,
+            tokenizer.no_timestamps,
+            *text_tokens,
+            tokenizer.eot,
+        ]
+    ).to(model.device)
+
+    # install hooks on the cross attention layers to retrieve the attention weights
+    QKs = [None] * model.dims.n_text_layer
+    hooks = [
+        block.cross_attn.register_forward_hook(
+            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
+        )
+        for i, block in enumerate(model.decoder.blocks)
+    ]
+
+    with torch.no_grad():
+        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
+        token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
+        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
+
+    for hook in hooks:
+        hook.remove()
+
+    # heads * tokens * frames
+    weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
+    weights = weights[:, :, : num_frames // 2]
+    weights = (weights * qk_scale).softmax(dim=-1)
+    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
+    weights = (weights - mean) / std
+    weights = median_filter(weights, medfilt_width)
+
+    matrix = weights.mean(axis=0)
+    matrix = matrix[len(tokenizer.sot_sequence):-1]
+    text_indices, time_indices = dtw(-matrix)
+
+    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
+    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
+
+    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
+    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
+    start_times = jump_times[word_boundaries[:-1]]
+    end_times = jump_times[word_boundaries[1:]]
+    word_probabilities = [
+        np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+    ]
+
+    # hack: ensure the first and second word is not longer than twice the median word duration.
+    # a better segmentation algorithm based on VAD should be able to replace this.
+    word_durations = end_times - start_times
+    word_durations = word_durations[word_durations.nonzero()]
+    if len(word_durations) > 0:
+        median_duration = np.median(word_durations)
+        max_duration = median_duration * 2
+        if len(word_durations) >= 2 and word_durations[1] > max_duration:
+            end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
+        if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
+            start_times[0] = max(0, end_times[0] - max_duration)
+
+    return [
+        WordTiming(word, tokens, start, end, probability)
+        for word, tokens, start, end, probability in zip(
+            words, word_tokens, start_times, end_times, word_probabilities
+        )
+    ]
+
+
+def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
+    # merge prepended punctuations
+    i = len(alignment) - 2
+    j = len(alignment) - 1
+    while i >= 0:
+        previous = alignment[i]
+        following = alignment[j]
+        if previous.word.startswith(" ") and previous.word.strip() in prepended:
+            # prepend it to the following word
+            following.word = previous.word + following.word
+            following.tokens = previous.tokens + following.tokens
+            previous.word = ""
+            previous.tokens = []
+        else:
+            j = i
+        i -= 1
+
+    # merge appended punctuations
+    i = 0
+    j = 1
+    while j < len(alignment):
+        previous = alignment[i]
+        following = alignment[j]
+        if not previous.word.endswith(" ") and following.word in appended:
+            # append it to the previous word
+            previous.word = previous.word + following.word
+            previous.tokens = previous.tokens + following.tokens
+            following.word = ""
+            following.tokens = []
+        else:
+            i = j
+        j += 1
+
+
+def add_word_timestamps(
+    *,
+    segments: List[dict],
+    model: "Whisper",
+    tokenizer: Tokenizer,
+    mel: torch.Tensor,
+    num_frames: int,
+    prepend_punctuations: str = "\"\'“¿([{-",
+    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
+    **hyperparams,
+):
+    if len(segments) == 0:
+        return
+
+    text_tokens = [t for segment in segments for t in segment["tokens"]]
+    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
+    merge_punctuations(alignment, prepend_punctuations, append_punctuations)
+
+    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
+    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
+
+    for segment in segments:
+        segment["words"] = []
+
+    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
+    for i, timing in enumerate(alignment):
+        if timing.word:
+            segment = segments[token_sources[word_boundaries[i]]]
+            start = round(time_offset + timing.start, 2)
+            end = round(time_offset + timing.end, 2)
+            segment["words"].append(
+                dict(word=timing.word, start=start, end=end, probability=timing.probability)
+            )
+
+    for segment in segments:
+        if len(words := segment["words"]) > 0:
+            # adjust the segment-level timestamps based on the word-level timestamps
+            segment["start"] = words[0]["start"]
+            segment["end"] = words[-1]["end"]

+ 43 - 0
whisper/tokenizer.py

@@ -1,4 +1,5 @@
 import os
+import string
 from dataclasses import dataclass
 from functools import lru_cache, cached_property
 from typing import List, Optional, Tuple, Union
@@ -265,6 +266,48 @@ class Tokenizer:
         assert len(tokens) == 1, f"{text} is not encoded as a single token"
         return tokens[0]
 
+    def split_to_word_tokens(self, tokens: List[int]):
+        if self.language in {"zh", "ja", "th", "lo", "my"}:
+            # These languages don't typically use spaces, so it is difficult to split words
+            # without morpheme analysis. Here, we instead split words at any
+            # position where the tokens are decoded as valid unicode points
+            return self.split_tokens_on_unicode(tokens)
+
+        return self.split_tokens_on_spaces(tokens)
+
+    def split_tokens_on_unicode(self, tokens: List[int]):
+        words = []
+        word_tokens = []
+        current_tokens = []
+
+        for token in tokens:
+            current_tokens.append(token)
+            decoded = self.decode_with_timestamps(current_tokens)
+            if "\ufffd" not in decoded:
+                words.append(decoded)
+                word_tokens.append(current_tokens)
+                current_tokens = []
+
+        return words, word_tokens
+
+    def split_tokens_on_spaces(self, tokens: List[int]):
+        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
+        words = []
+        word_tokens = []
+
+        for subword, subword_tokens in zip(subwords, subword_tokens_list):
+            special = subword_tokens[0] >= self.eot
+            with_space = subword.startswith(" ")
+            punctuation = subword.strip() in string.punctuation
+            if special or with_space or punctuation or len(words) == 0:
+                words.append(subword)
+                word_tokens.append(subword_tokens)
+            else:
+                words[-1] = words[-1] + subword
+                word_tokens[-1].extend(subword_tokens)
+
+        return words, word_tokens
+
 
 @lru_cache(maxsize=None)
 def build_tokenizer(name: str = "gpt2"):

+ 108 - 52
whisper/transcribe.py

@@ -7,8 +7,9 @@ import numpy as np
 import torch
 import tqdm
 
-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
+from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult
+from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
 
@@ -27,6 +28,9 @@ def transcribe(
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
+    word_timestamps: bool = False,
+    prepend_punctuations: str = "\"\'“¿([{-",
+    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
     **decode_options,
 ):
     """
@@ -63,6 +67,21 @@ def transcribe(
         disabling may make the text inconsistent across windows, but the model becomes less prone to
         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
 
+    word_timestamps: bool
+        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
+        and include the timestamps for each word in each segment.
+
+    prepend_punctuations: str
+        If word_timestamps is True, merge these punctuation symbols with the next word
+
+    append_punctuations: str
+        If word_timestamps is True, merge these punctuation symbols with the previous word
+
+    initial_prompt: Optional[str]
+        Optional text to provide as a prompt for the first window. This can be used to provide, or
+        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
+        to make it more likely to predict those word correctly.
+
     decode_options: dict
         Keyword arguments to construct `DecodingOptions` instances
 
@@ -90,16 +109,19 @@ def transcribe(
         else:
             if verbose:
                 print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
-            segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-            _, probs = model.detect_language(segment)
+            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+            _, probs = model.detect_language(mel_segment)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
                 print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
-    language = decode_options["language"]
-    task = decode_options.get("task", "transcribe")
+    language: str = decode_options["language"]
+    task: str = decode_options.get("task", "transcribe")
     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
 
+    if word_timestamps and task == "translate":
+        warnings.warn("Word-level timestamps on translations may not be reliable.")
+
     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
         temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
         decode_result = None
@@ -145,42 +167,35 @@ def transcribe(
     else:
         initial_prompt_tokens = []
 
-    def add_segment(
-        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
+    def new_segment(
+        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
     ):
-        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
-        if len(text.strip()) == 0:  # skip empty text output
-            return
-
-        all_segments.append(
-            {
-                "id": len(all_segments),
-                "seek": seek,
-                "start": start,
-                "end": end,
-                "text": text,
-                "tokens": text_tokens.tolist(),
-                "temperature": result.temperature,
-                "avg_logprob": result.avg_logprob,
-                "compression_ratio": result.compression_ratio,
-                "no_speech_prob": result.no_speech_prob,
-            }
-        )
-        if verbose:
-            print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
-
-    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
+        text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
+        return {
+            "id": len(all_segments),
+            "seek": seek,
+            "start": start,
+            "end": end,
+            "text": tokenizer.decode(text_tokens),
+            "tokens": text_tokens,
+            "temperature": result.temperature,
+            "avg_logprob": result.avg_logprob,
+            "compression_ratio": result.compression_ratio,
+            "no_speech_prob": result.no_speech_prob,
+        }
+
+    # show the progress bar when verbose is False (if True, transcribed text will be printed)
     num_frames = mel.shape[-1]
-    previous_seek_value = seek
-
     with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
-            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
-            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            mel_segment = mel[:, seek:]
+            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
 
             decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result: DecodingResult = decode_with_fallback(segment)
+            result: DecodingResult = decode_with_fallback(mel_segment)
             tokens = torch.tensor(result.tokens)
 
             if no_speech_threshold is not None:
@@ -191,29 +206,36 @@ def transcribe(
                     should_skip = False
 
                 if should_skip:
-                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
+                    seek += segment_size  # fast-forward to the next segment boundary
                     continue
 
+            previous_seek = seek
+            current_segments = []
+            current_tokens = []
+
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
             if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
                 if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
                     consecutive = consecutive.tolist() + [len(tokens)]
+
                 last_slice = 0
                 for current_slice in consecutive:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
                     end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                    add_segment(
-                        start=timestamp_offset + start_timestamp_pos * time_precision,
-                        end=timestamp_offset + end_timestamp_pos * time_precision,
-                        text_tokens=sliced_tokens[1:-1],
+                    current_segments.append(new_segment(
+                        start=time_offset + start_timestamp_pos * time_precision,
+                        end=time_offset + end_timestamp_pos * time_precision,
+                        tokens=sliced_tokens,
                         result=result,
-                    )
+                    ))
+                    current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
+
                 if ended_with_single_timestamp:
                     # single timestamp at the end means no speech after the last timestamp.
-                    seek += segment.shape[-1]
+                    seek += segment_size
                 else:
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
                     last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
@@ -227,23 +249,54 @@ def transcribe(
                     last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
                     duration = last_timestamp_pos * time_precision
 
-                add_segment(
-                    start=timestamp_offset,
-                    end=timestamp_offset + duration,
-                    text_tokens=tokens,
+                current_segments.append(new_segment(
+                    start=time_offset,
+                    end=time_offset + duration,
+                    tokens=tokens,
                     result=result,
-                )
-
-                seek += segment.shape[-1]
-                all_tokens.extend(tokens.tolist())
+                ))
+                current_tokens.append(tokens.tolist())
+                seek += segment_size
 
             if not condition_on_previous_text or result.temperature > 0.5:
                 # do not feed the prompt tokens if a high temperature was used
                 prompt_reset_since = len(all_tokens)
 
+            if word_timestamps:
+                add_word_timestamps(
+                    segments=current_segments,
+                    model=model,
+                    tokenizer=tokenizer,
+                    mel=mel_segment,
+                    num_frames=segment_size,
+                    prepend_punctuations=prepend_punctuations,
+                    append_punctuations=append_punctuations,
+                )
+                word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
+                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                    seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND)
+                    if seek_shift > 0:
+                        seek = previous_seek + seek_shift
+
+            if verbose:
+                for segment in current_segments:
+                    start, end, text = segment["start"], segment["end"], segment["text"]
+                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
+                    print(make_safe(line))
+
+            # if a segment is instantaneous or does not contain text, clear it
+            for i, segment in enumerate(current_segments):
+                if segment["start"] == segment["end"] or segment["text"].strip() == "":
+                    segment["text"] = ""
+                    segment["tokens"] = []
+                    segment["words"] = []
+                    current_tokens[i] = []
+
+            all_segments.extend(current_segments)
+            all_tokens.extend([token for segment in current_tokens for token in segment])
+
             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek_value)
-            previous_seek_value = seek
+            pbar.update(min(num_frames, seek) - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
@@ -282,6 +335,9 @@ def cli():
     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+    parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
+    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
+    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
 
     args = parser.parse_args().__dict__

+ 92 - 0
whisper/triton_ops.py

@@ -0,0 +1,92 @@
+import math
+
+import numpy as np
+import torch
+from functools import lru_cache
+
+try:
+    import triton
+    import triton.language as tl
+except ImportError:
+    raise RuntimeError("triton import failed; try `pip install --pre triton`")
+
+
+@triton.jit
+def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < M
+
+    for k in range(1, N + M + 1):  # k = i + j
+        tl.debug_barrier()
+
+        p0 = cost + (k - 1) * cost_stride
+        p1 = cost + k * cost_stride
+        p2 = cost + k * cost_stride + 1
+
+        c0 = tl.load(p0 + offsets, mask=mask)
+        c1 = tl.load(p1 + offsets, mask=mask)
+        c2 = tl.load(p2 + offsets, mask=mask)
+
+        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
+        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
+
+        cost_ptr = cost + (k + 1) * cost_stride + 1
+        tl.store(cost_ptr + offsets, cost_row, mask=mask)
+
+        trace_ptr = trace + (k + 1) * trace_stride + 1
+        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
+        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
+        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
+
+
+@lru_cache(maxsize=None)
+def median_kernel(filter_width: int):
+    @triton.jit
+    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == filter_width
+        row_idx = tl.program_id(0)
+        offsets = tl.arange(0, BLOCK_SIZE)
+        mask = offsets < y_stride
+
+        x_ptr = x + row_idx * x_stride
+        y_ptr = y + row_idx * y_stride
+
+        LOAD_ALL_ROWS_HERE
+
+        BUBBLESORT_HERE
+
+        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)
+
+    kernel = triton.JITFunction(kernel.fn)
+    kernel.src = kernel.src.replace("    LOAD_ALL_ROWS_HERE", "\n".join([
+        f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
+        for i in range(filter_width)
+    ]))
+    kernel.src = kernel.src.replace("    BUBBLESORT_HERE", "\n\n".join([
+        "\n\n".join([
+            "\n".join([
+                f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
+                f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
+                f"    row{j} = smaller",
+                f"    row{j + 1} = larger",
+            ])
+            for j in range(filter_width - i - 1)
+        ])
+        for i in range(filter_width // 2 + 1)
+    ]))
+    kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
+
+    return kernel
+
+
+def median_filter_cuda(x: torch.Tensor, filter_width: int):
+    """Apply a median filter of given width along the last dimension of x"""
+    slices = x.contiguous().unfold(-1, filter_width, 1)
+    grid = np.prod(slices.shape[:-2])
+
+    kernel = median_kernel(filter_width)
+    y = torch.empty_like(slices[..., 0])
+
+    BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
+    kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
+
+    return y

+ 48 - 19
whisper/utils.py

@@ -85,34 +85,63 @@ class WriteTXT(ResultWriter):
             print(segment['text'].strip(), file=file, flush=True)
 
 
-class WriteVTT(ResultWriter):
+class SubtitlesWriter(ResultWriter):
+    always_include_hours: bool
+    decimal_marker: str
+
+    def iterate_result(self, result: dict):
+        for segment in result["segments"]:
+            segment_start = self.format_timestamp(segment["start"])
+            segment_end = self.format_timestamp(segment["end"])
+            segment_text = segment['text'].strip().replace('-->', '->')
+
+            if word_timings := segment.get("words", None):
+                all_words = [timing["word"] for timing in word_timings]
+                all_words[0] = all_words[0].strip()  # remove the leading space, if any
+                last = segment_start
+                for i, this_word in enumerate(word_timings):
+                    start = self.format_timestamp(this_word["start"])
+                    end = self.format_timestamp(this_word["end"])
+                    if last != start:
+                        yield last, start, segment_text
+
+                    yield start, end, "".join(
+                        [f"<u>{word}</u>" if j == i else word for j, word in enumerate(all_words)]
+                    )
+                    last = end
+
+                if last != segment_end:
+                    yield last, segment_end, segment_text
+            else:
+                yield segment_start, segment_end, segment_text
+
+    def format_timestamp(self, seconds: float):
+        return format_timestamp(
+            seconds=seconds,
+            always_include_hours=self.always_include_hours,
+            decimal_marker=self.decimal_marker,
+        )
+
+
+class WriteVTT(SubtitlesWriter):
     extension: str = "vtt"
+    always_include_hours: bool = False
+    decimal_marker: str = '.'
 
     def write_result(self, result: dict, file: TextIO):
         print("WEBVTT\n", file=file)
-        for segment in result["segments"]:
-            print(
-                f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
-                f"{segment['text'].strip().replace('-->', '->')}\n",
-                file=file,
-                flush=True,
-            )
+        for start, end, text in self.iterate_result(result):
+            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
 
 
-class WriteSRT(ResultWriter):
+class WriteSRT(SubtitlesWriter):
     extension: str = "srt"
+    always_include_hours: bool = True
+    decimal_marker: str = ','
 
     def write_result(self, result: dict, file: TextIO):
-        for i, segment in enumerate(result["segments"], start=1):
-            # write srt lines
-            print(
-                f"{i}\n"
-                f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
-                f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
-                f"{segment['text'].strip().replace('-->', '->')}\n",
-                file=file,
-                flush=True,
-            )
+        for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
+            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
 
 
 class WriteTSV(ResultWriter):