Browse Source

apply formatting with `black` (#1038)

* applying black (with the default 88-column limit)

* add flake8

* add isort

* fix isort
Jong Wook Kim 1 year ago
parent
commit
b80bcf610d

+ 4 - 0
.flake8

@@ -0,0 +1,4 @@
+[flake8]
+per-file-ignores =
+    */__init__.py: F401
+

+ 3 - 0
.github/workflows/test.yml

@@ -22,4 +22,7 @@ jobs:
       - uses: actions/checkout@v2
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
       - run: pip install .["dev"]
+      - run: black --check --diff -t py38 --include '(\.pyi?)$' .
+      - run: isort --check --diff .
+      - run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
       - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

+ 8 - 0
pyproject.toml

@@ -0,0 +1,8 @@
+[tool.black]
+
+[tool.isort]
+profile = "black"
+include_trailing_comma = true
+line_length = 88
+multi_line_output = 3
+

+ 8 - 4
setup.py

@@ -2,7 +2,7 @@ import os
 import sys
 
 import pkg_resources
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 
 def read_version(fname="whisper/version.py"):
@@ -16,7 +16,10 @@ if sys.platform.startswith("linux"):
     try:
         import re
         import subprocess
-        version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
+
+        version_line = (
+            subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
+        )
         major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
         if (int(major), int(minor)) < (11, 4):
             # the last version supporting CUDA < 11.4
@@ -38,7 +41,8 @@ setup(
     url="https://github.com/openai/whisper",
     license="MIT",
     packages=find_packages(exclude=["tests*"]),
-    install_requires=requirements + [
+    install_requires=requirements
+    + [
         str(r)
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
@@ -48,5 +52,5 @@ setup(
         "console_scripts": ["whisper=whisper.transcribe:cli"],
     },
     include_package_data=True,
-    extras_require={"dev": ["pytest", "scipy"]},
+    extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
 )

+ 1 - 1
tests/test_audio.py

@@ -2,7 +2,7 @@ import os.path
 
 import numpy as np
 
-from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
+from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
 
 
 def test_audio():

+ 4 - 1
tests/test_normalizer.py

@@ -1,7 +1,10 @@
 import pytest
 
 from whisper.normalizers import EnglishTextNormalizer
-from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
+from whisper.normalizers.english import (
+    EnglishNumberNormalizer,
+    EnglishSpellingNormalizer,
+)
 
 
 @pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])

+ 15 - 6
tests/test_timing.py

@@ -1,16 +1,21 @@
-import pytest
 import numpy as np
+import pytest
 import scipy.ndimage
 import torch
 
 from whisper.timing import dtw_cpu, dtw_cuda, median_filter
 
-
 sizes = [
-    (10, 20), (32, 16), (123, 1500), (234, 189),
+    (10, 20),
+    (32, 16),
+    (123, 1500),
+    (234, 189),
 ]
 shapes = [
-    (10,), (1, 15),  (4, 5, 345), (6, 12, 240, 512),
+    (10,),
+    (1, 15),
+    (4, 5, 345),
+    (6, 12, 240, 512),
 ]
 
 
@@ -68,8 +73,12 @@ def test_median_filter(shape):
 
         # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
         pad_width = filter_width // 2
-        padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
-        scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
+        padded_x = np.pad(
+            x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
+        )
+        scipy_filtered = scipy.ndimage.median_filter(
+            padded_x, [1] * (x.ndim - 1) + [filter_width]
+        )
         scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
 
         assert np.allclose(filtered, scipy_filtered)

+ 3 - 1
tests/test_transcribe.py

@@ -13,7 +13,9 @@ def test_transcribe(model_name: str):
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
 
     language = "en" if model_name.endswith(".en") else None
-    result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
+    result = model.transcribe(
+        audio_path, language=language, temperature=0.0, word_timestamps=True
+    )
     assert result["language"] == "en"
 
     transcription = result["text"].lower()

+ 30 - 20
whisper/__init__.py

@@ -10,11 +10,10 @@ from tqdm import tqdm
 
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
-from .model import Whisper, ModelDimensions
+from .model import ModelDimensions, Whisper
 from .transcribe import transcribe
 from .version import __version__
 
-
 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
     "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -41,12 +40,11 @@ _ALIGNMENT_HEADS = {
     "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
     "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
-    "large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
-    "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
+    "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
 }
 
 
-
 def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)
 
@@ -62,10 +60,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
             return model_bytes if in_memory else download_target
         else:
-            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+            warnings.warn(
+                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
+            )
 
     with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as loop:
             while True:
                 buffer = source.read(8192)
                 if not buffer:
@@ -76,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
 
     model_bytes = open(download_target, "rb").read()
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
-        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
+        )
 
     return model_bytes if in_memory else download_target
 
@@ -86,7 +94,12 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())
 
 
-def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
+def load_model(
+    name: str,
+    device: Optional[Union[str, torch.device]] = None,
+    download_root: str = None,
+    in_memory: bool = False,
+) -> Whisper:
     """
     Load a Whisper ASR model
 
@@ -111,15 +124,8 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     if download_root is None:
-        download_root = os.path.join(
-            os.getenv(
-                "XDG_CACHE_HOME",
-                os.path.join(
-                    os.path.expanduser("~"), ".cache"
-                )
-            ),
-            "whisper"
-        )
+        default = os.path.join(os.path.expanduser("~"), ".cache")
+        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
 
     if name in _MODELS:
         checkpoint_file = _download(_MODELS[name], download_root, in_memory)
@@ -128,9 +134,13 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
         checkpoint_file = open(name, "rb").read() if in_memory else name
         alignment_heads = None
     else:
-        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+        raise RuntimeError(
+            f"Model {name} not found; available models = {available_models()}"
+        )
 
-    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
+    with (
+        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
+    ) as fp:
         checkpoint = torch.load(fp, map_location=device)
     del checkpoint_file
 

+ 0 - 1
whisper/__main__.py

@@ -1,4 +1,3 @@
 from .transcribe import cli
 
-
 cli()

+ 14 - 6
whisper/audio.py

@@ -16,11 +16,13 @@ N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
-N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input
+N_FRAMES = exact_div(
+    N_SAMPLES, HOP_LENGTH
+)  # 3000: number of frames in a mel spectrogram input
 
 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
-FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 100 mel frames in 1s (10ms each)
-TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 audio tokens in 1s (20ms each)
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
 
 
 def load_audio(file: str, sr: int = SAMPLE_RATE):
@@ -59,7 +61,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     """
     if torch.is_tensor(array):
         if array.shape[axis] > length:
-            array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
+            array = array.index_select(
+                dim=axis, index=torch.arange(length, device=array.device)
+            )
 
         if array.shape[axis] < length:
             pad_widths = [(0, 0)] * array.ndim
@@ -89,11 +93,15 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
         )
     """
     assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
-    with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
+    with np.load(
+        os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+    ) as f:
         return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
 
 
-def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
+def log_mel_spectrogram(
+    audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
+):
     """
     Compute the log-Mel spectrogram of
 

+ 145 - 53
whisper/decoding.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -16,7 +16,9 @@ if TYPE_CHECKING:
 
 
 @torch.no_grad()
-def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
+def detect_language(
+    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
+) -> Tuple[Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
     of the most probable language tokens and the probability distribution over all language tokens.
@@ -31,8 +33,13 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
     """
     if tokenizer is None:
         tokenizer = get_tokenizer(model.is_multilingual)
-    if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
-        raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
+    if (
+        tokenizer.language is None
+        or tokenizer.language_token not in tokenizer.sot_sequence
+    ):
+        raise ValueError(
+            "This model doesn't have language tokens so it can't perform lang id"
+        )
 
     single = mel.ndim == 2
     if single:
@@ -70,31 +77,36 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
 
 @dataclass(frozen=True)
 class DecodingOptions:
-    task: str = "transcribe"  # whether to perform X->X "transcribe" or X->English "translate"
-    language: Optional[str] = None  # language that the audio is in; uses detected language if None
+    # whether to perform X->X "transcribe" or X->English "translate"
+    task: str = "transcribe"
+
+    # language that the audio is in; uses detected language if None
+    language: Optional[str] = None
 
     # sampling-related options
     temperature: float = 0.0
     sample_len: Optional[int] = None  # maximum number of tokens to sample
-    best_of: Optional[int] = None     # number of independent samples to collect, when t > 0
-    beam_size: Optional[int] = None   # number of beams in beam search, when t == 0
-    patience: Optional[float] = None  # patience in beam search (https://arxiv.org/abs/2204.05424)
+    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
+    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
+    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)
 
-    # options for ranking generations (either beams or best-of-N samples)
-    length_penalty: Optional[float] = None   # "alpha" in Google NMT, None defaults to length norm
+    # "alpha" in Google NMT, or None for length norm, when ranking generations
+    # to select which to return among the beams or best-of-N samples
+    length_penalty: Optional[float] = None
 
-    # prompt, prefix, and token suppression
-    prompt: Optional[Union[str, List[int]]] = None   # text or tokens for the previous context
-    prefix: Optional[Union[str, List[int]]] = None   # text or tokens to prefix the current context
-    suppress_blank: bool = True                      # this will suppress blank outputs
+    # text or tokens to feed as the prompt or the prefix; for more info:
+    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
+    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
+    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context
 
     # list of tokens ids (or comma-separated token ids) to suppress
     # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
     suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
+    suppress_blank: bool = True  # this will suppress blank outputs
 
     # timestamp sampling options
-    without_timestamps: bool = False              # use <|notimestamps|> to sample text tokens only
-    max_initial_timestamp: Optional[float] = 1.0  # the initial timestamp cannot be later than this
+    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
+    max_initial_timestamp: Optional[float] = 1.0
 
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
@@ -158,7 +170,9 @@ class PyTorchInference(Inference):
 
 
 class SequenceRanker:
-    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
+    def rank(
+        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
+    ) -> List[int]:
         """
         Given a list of groups of samples and their cumulative log probabilities,
         return the indices of the samples in each group to select as the final result
@@ -196,7 +210,9 @@ class TokenDecoder:
     def reset(self):
         """Initialize any stateful variables for decoding a new sequence"""
 
-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         """Specify how to select the next token, based on the current trace and logits
 
         Parameters
@@ -251,7 +267,9 @@ class GreedyDecoder(TokenDecoder):
         self.temperature = temperature
         self.eot = eot
 
-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
@@ -274,7 +292,13 @@ class GreedyDecoder(TokenDecoder):
 
 
 class BeamSearchDecoder(TokenDecoder):
-    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
+    def __init__(
+        self,
+        beam_size: int,
+        eot: int,
+        inference: Inference,
+        patience: Optional[float] = None,
+    ):
         self.beam_size = beam_size
         self.eot = eot
         self.inference = inference
@@ -282,12 +306,16 @@ class BeamSearchDecoder(TokenDecoder):
         self.max_candidates: int = round(beam_size * self.patience)
         self.finished_sequences = None
 
-        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
+        assert (
+            self.max_candidates > 0
+        ), f"Invalid beam size ({beam_size}) or patience ({patience})"
 
     def reset(self):
         self.finished_sequences = None
 
-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         if tokens.shape[0] % self.beam_size != 0:
             raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
 
@@ -331,7 +359,9 @@ class BeamSearchDecoder(TokenDecoder):
 
         # add newly finished sequences to self.finished_sequences
         assert len(self.finished_sequences) == len(finished_sequences)
-        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
+        for previously_finished, newly_finished in zip(
+            self.finished_sequences, finished_sequences
+        ):
             for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                 if len(previously_finished) >= self.max_candidates:
                     break  # the candidate list is full
@@ -339,7 +369,8 @@ class BeamSearchDecoder(TokenDecoder):
 
         # mark as completed if all audio has enough number of samples
         completed = all(
-            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
+            len(sequences) >= self.max_candidates
+            for sequences in self.finished_sequences
         )
         return tokens, completed
 
@@ -347,7 +378,9 @@ class BeamSearchDecoder(TokenDecoder):
         # collect all finished sequences, including patience, and add unfinished ones if not enough
         sum_logprobs = sum_logprobs.cpu()
         for i, sequences in enumerate(self.finished_sequences):
-            if len(sequences) < self.beam_size:  # when not enough sequences are finished
+            if (
+                len(sequences) < self.beam_size
+            ):  # when not enough sequences are finished
                 for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                     sequence = preceding_tokens[i, j].tolist() + [self.eot]
                     sequences[tuple(sequence)] = sum_logprobs[i][j].item()
@@ -355,7 +388,8 @@ class BeamSearchDecoder(TokenDecoder):
                         break
 
         tokens: List[List[Tensor]] = [
-            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
+            [torch.tensor(seq) for seq in sequences.keys()]
+            for sequences in self.finished_sequences
         ]
         sum_logprobs: List[List[float]] = [
             list(sequences.values()) for sequences in self.finished_sequences
@@ -399,7 +433,10 @@ class SuppressTokens(LogitFilter):
 
 class ApplyTimestampRules(LogitFilter):
     def __init__(
-        self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
+        self,
+        tokenizer: Tokenizer,
+        sample_begin: int,
+        max_initial_timestamp_index: Optional[int],
     ):
         self.tokenizer = tokenizer
         self.sample_begin = sample_begin
@@ -414,8 +451,12 @@ class ApplyTimestampRules(LogitFilter):
         for k in range(tokens.shape[0]):
             sampled_tokens = tokens[k, self.sample_begin :]
             seq = [t for t in sampled_tokens.tolist()]
-            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
-            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+            last_was_timestamp = (
+                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
+            )
+            penultimate_was_timestamp = (
+                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+            )
 
             if last_was_timestamp:
                 if penultimate_was_timestamp:  # has to be non-timestamp
@@ -423,7 +464,9 @@ class ApplyTimestampRules(LogitFilter):
                 else:  # cannot be normal text tokens
                     logits[k, : self.tokenizer.eot] = -np.inf
 
-            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
+            timestamps = sampled_tokens[
+                sampled_tokens.ge(self.tokenizer.timestamp_begin)
+            ]
             if timestamps.numel() > 0:
                 # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                 logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
@@ -434,13 +477,17 @@ class ApplyTimestampRules(LogitFilter):
 
             # apply the `max_initial_timestamp` option
             if self.max_initial_timestamp_index is not None:
-                last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                last_allowed = (
+                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                )
                 logits[:, last_allowed + 1 :] = -np.inf
 
         # if sum of probability over timestamps is above any other token, sample timestamp
         logprobs = F.log_softmax(logits.float(), dim=-1)
         for k in range(tokens.shape[0]):
-            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
+            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
+                dim=-1
+            )
             max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
             if timestamp_logprob > max_text_token_logprob:
                 logits[k, : self.tokenizer.timestamp_begin] = -np.inf
@@ -456,7 +503,9 @@ class DecodingTask:
         self.model = model
 
         language = options.language or "en"
-        tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
+        tokenizer = get_tokenizer(
+            model.is_multilingual, language=language, task=options.task
+        )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
 
@@ -496,9 +545,13 @@ class DecodingTask:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None
             if options.max_initial_timestamp:
-                max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
+                max_initial_timestamp_index = round(
+                    self.options.max_initial_timestamp / precision
+                )
             self.logit_filters.append(
-                ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
+                ApplyTimestampRules(
+                    tokenizer, self.sample_begin, max_initial_timestamp_index
+                )
             )
 
     def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
@@ -509,7 +562,9 @@ class DecodingTask:
                 raise ValueError("best_of with greedy sampling (T=0) is not compatible")
         if options.patience is not None and options.beam_size is None:
             raise ValueError("patience requires beam_size to be given")
-        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
+        if options.length_penalty is not None and not (
+            0 <= options.length_penalty <= 1
+        ):
             raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
 
         return options
@@ -519,7 +574,9 @@ class DecodingTask:
 
         if prefix := self.options.prefix:
             prefix_tokens = (
-                self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
+                self.tokenizer.encode(" " + prefix.strip())
+                if isinstance(prefix, str)
+                else prefix
             )
             if self.sample_len is not None:
                 max_prefix_len = self.n_ctx // 2 - self.sample_len
@@ -528,9 +585,15 @@ class DecodingTask:
 
         if prompt := self.options.prompt:
             prompt_tokens = (
-                self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
+                self.tokenizer.encode(" " + prompt.strip())
+                if isinstance(prompt, str)
+                else prompt
+            )
+            tokens = (
+                [self.tokenizer.sot_prev]
+                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
+                + tokens
             )
-            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
 
         return tuple(tokens)
 
@@ -554,7 +617,7 @@ class DecodingTask:
                 self.tokenizer.translate,
                 self.tokenizer.sot,
                 self.tokenizer.sot_prev,
-                self.tokenizer.sot_lm
+                self.tokenizer.sot_lm,
             ]
         )
         if self.tokenizer.no_speech is not None:
@@ -567,14 +630,21 @@ class DecodingTask:
         if self.options.fp16:
             mel = mel.half()
 
-        if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
+        if mel.shape[-2:] == (
+            self.model.dims.n_audio_ctx,
+            self.model.dims.n_audio_state,
+        ):
             # encoded audio features are given; skip audio encoding
             audio_features = mel
         else:
             audio_features = self.model.encoder(mel)
 
-        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
-            return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
+        if audio_features.dtype != (
+            torch.float16 if self.options.fp16 else torch.float32
+        ):
+            return TypeError(
+                f"audio_features has an incorrect dtype: {audio_features.dtype}"
+            )
 
         return audio_features
 
@@ -583,7 +653,9 @@ class DecodingTask:
         lang_probs = None
 
         if self.options.language is None or self.options.task == "lang_id":
-            lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
+            lang_tokens, lang_probs = self.model.detect_language(
+                audio_features, self.tokenizer
+            )
             languages = [max(probs, key=probs.get) for probs in lang_probs]
             if self.options.language is None:
                 tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
@@ -600,7 +672,9 @@ class DecodingTask:
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)
 
-                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
+                if (
+                    i == 0 and self.tokenizer.no_speech is not None
+                ):  # save no_speech_probs
                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                     no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
 
@@ -634,8 +708,12 @@ class DecodingTask:
         languages, language_probs = self._detect_language(audio_features, tokens)
         if self.options.task == "lang_id":
             return [
-                DecodingResult(audio_features=features, language=language, language_probs=probs)
-                for features, language, probs in zip(audio_features, languages, language_probs)
+                DecodingResult(
+                    audio_features=features, language=language, language_probs=probs
+                )
+                for features, language, probs in zip(
+                    audio_features, languages, language_probs
+                )
             ]
 
         # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
@@ -656,7 +734,8 @@ class DecodingTask:
         # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
         tokens: List[List[Tensor]] = [
-            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
+            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
+            for s in tokens
         ]
 
         # select the top-ranked sample in each group
@@ -665,9 +744,18 @@ class DecodingTask:
         texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
 
         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
-        avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
+        avg_logprobs: List[float] = [
+            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
+        ]
 
-        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
+        fields = (
+            texts,
+            languages,
+            tokens,
+            audio_features,
+            avg_logprobs,
+            no_speech_probs,
+        )
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
 
@@ -682,12 +770,16 @@ class DecodingTask:
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
             )
-            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
+                *fields
+            )
         ]
 
 
 @torch.no_grad()
-def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
+def decode(
+    model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
+) -> Union[DecodingResult, List[DecodingResult]]:
     """
     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
 

+ 52 - 22
whisper/model.py

@@ -1,16 +1,15 @@
 import base64
 import gzip
 from dataclasses import dataclass
-from typing import Dict
-from typing import Iterable, Optional
+from typing import Dict, Iterable, Optional
 
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch import Tensor
-from torch import nn
+from torch import Tensor, nn
 
-from .decoding import detect_language as detect_language_function, decode as decode_function
+from .decoding import decode as decode_function
+from .decoding import detect_language as detect_language_function
 from .transcribe import transcribe as transcribe_function
 
 
@@ -36,12 +35,16 @@ class LayerNorm(nn.LayerNorm):
 class Linear(nn.Linear):
     def forward(self, x: Tensor) -> Tensor:
         return F.linear(
-            x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
+            x,
+            self.weight.to(x.dtype),
+            None if self.bias is None else self.bias.to(x.dtype),
         )
 
 
 class Conv1d(nn.Conv1d):
-    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
+    def _conv_forward(
+        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
+    ) -> Tensor:
         return super()._conv_forward(
             x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
         )
@@ -87,7 +90,9 @@ class MultiHeadAttention(nn.Module):
         wv, qk = self.qkv_attention(q, k, v, mask)
         return self.out(wv), qk
 
-    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
+    def qkv_attention(
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+    ):
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
@@ -110,11 +115,15 @@ class ResidualAttentionBlock(nn.Module):
         self.attn = MultiHeadAttention(n_state, n_head)
         self.attn_ln = LayerNorm(n_state)
 
-        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
+        self.cross_attn = (
+            MultiHeadAttention(n_state, n_head) if cross_attention else None
+        )
         self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
 
         n_mlp = n_state * 4
-        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
+        self.mlp = nn.Sequential(
+            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+        )
         self.mlp_ln = LayerNorm(n_state)
 
     def forward(
@@ -132,7 +141,9 @@ class ResidualAttentionBlock(nn.Module):
 
 
 class AudioEncoder(nn.Module):
-    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+    def __init__(
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
         super().__init__()
         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
         self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
@@ -163,14 +174,19 @@ class AudioEncoder(nn.Module):
 
 
 class TextDecoder(nn.Module):
-    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+    def __init__(
+        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
         super().__init__()
 
         self.token_embedding = nn.Embedding(n_vocab, n_state)
         self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
 
         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
-            [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
+            [
+                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+                for _ in range(n_layer)
+            ]
         )
         self.ln = LayerNorm(n_state)
 
@@ -185,14 +201,19 @@ class TextDecoder(nn.Module):
             the encoded audio features to be attended on
         """
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
+        x = (
+            self.token_embedding(x)
+            + self.positional_embedding[offset : offset + x.shape[-1]]
+        )
         x = x.to(xa.dtype)
 
         for block in self.blocks:
             x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
+        logits = (
+            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+        ).float()
 
         return logits
 
@@ -216,13 +237,19 @@ class Whisper(nn.Module):
             self.dims.n_text_layer,
         )
         # use the last half layers for alignment by default; see `set_alignment_heads()` below
-        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
-        all_heads[self.dims.n_text_layer // 2:] = True
+        all_heads = torch.zeros(
+            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
+        )
+        all_heads[self.dims.n_text_layer // 2 :] = True
         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
 
     def set_alignment_heads(self, dump: bytes):
-        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
-        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
+        array = np.frombuffer(
+            gzip.decompress(base64.b85decode(dump)), dtype=bool
+        ).copy()
+        mask = torch.from_numpy(array).reshape(
+            self.dims.n_text_layer, self.dims.n_text_head
+        )
         self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
 
     def embed_audio(self, mel: torch.Tensor):
@@ -231,7 +258,9 @@ class Whisper(nn.Module):
     def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
         return self.decoder(tokens, audio_features)
 
-    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def forward(
+        self, mel: torch.Tensor, tokens: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
         return self.decoder(tokens, self.encoder(mel))
 
     @property
@@ -260,8 +289,9 @@ class Whisper(nn.Module):
         hooks = []
 
         def save_to_cache(module, _, output):
-            if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
-                cache[module] = output  # save as-is, for the first token or cross attention
+            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
+                # save as-is, for the first token or cross attention
+                cache[module] = output
             else:
                 cache[module] = torch.cat([cache[module], output], dim=1).detach()
             return cache[module]

+ 2 - 2
whisper/normalizers/__init__.py

@@ -1,2 +1,2 @@
-from .basic import BasicTextNormalizer
-from .english import EnglishTextNormalizer
+from .basic import BasicTextNormalizer as BasicTextNormalizer
+from .english import EnglishTextNormalizer as EnglishTextNormalizer

+ 8 - 3
whisper/normalizers/basic.py

@@ -48,13 +48,16 @@ def remove_symbols(s: str):
     Replace any other markers, symbols, punctuations with a space, keeping diacritics
     """
     return "".join(
-        " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
+        " " if unicodedata.category(c)[0] in "MSP" else c
+        for c in unicodedata.normalize("NFKC", s)
     )
 
 
 class BasicTextNormalizer:
     def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
-        self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
+        self.clean = (
+            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
+        )
         self.split_letters = split_letters
 
     def __call__(self, s: str):
@@ -66,6 +69,8 @@ class BasicTextNormalizer:
         if self.split_letters:
             s = " ".join(regex.findall(r"\X", s, regex.U))
 
-        s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space
+        s = re.sub(
+            r"\s+", " ", s
+        )  # replace any successive whitespace characters with a space
 
         return s

+ 14 - 7
whisper/normalizers/english.py

@@ -84,7 +84,8 @@ class EnglishNumberNormalizer:
             name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
         }
         self.tens_ordinal = {
-            name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
+            name.replace("y", "ieth"): (value, "th")
+            for name, value in self.tens.items()
         }
         self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
 
@@ -108,7 +109,10 @@ class EnglishNumberNormalizer:
         self.multipliers_ordinal = {
             name + "th": (value, "th") for name, value in self.multipliers.items()
         }
-        self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
+        self.multipliers_suffixed = {
+            **self.multipliers_plural,
+            **self.multipliers_ordinal,
+        }
         self.decimals = {*self.ones, *self.tens, *self.zeros}
 
         self.preceding_prefixers = {
@@ -128,7 +132,8 @@ class EnglishNumberNormalizer:
             "cents": "¢",
         }
         self.prefixes = set(
-            list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
+            list(self.preceding_prefixers.values())
+            + list(self.following_prefixers.values())
         )
         self.suffixers = {
             "per": {"cent": "%"},
@@ -218,7 +223,9 @@ class EnglishNumberNormalizer:
                 if value is None:
                     value = ones
                 elif isinstance(value, str) or prev in self.ones:
-                    if prev in self.tens and ones < 10:  # replace the last zero with the digit
+                    if (
+                        prev in self.tens and ones < 10
+                    ):  # replace the last zero with the digit
                         assert value[-1] == "0"
                         value = value[:-1] + str(ones)
                     else:
@@ -522,14 +529,14 @@ class EnglishTextNormalizer:
         s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
         s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
         s = re.sub(self.ignore_patterns, "", s)
-        s = re.sub(r"\s+'", "'", s)  # standardize when there's a space before an apostrophe
+        s = re.sub(r"\s+'", "'", s)  # when there's a space before an apostrophe
 
         for pattern, replacement in self.replacers.items():
             s = re.sub(pattern, replacement, s)
 
         s = re.sub(r"(\d),(\d)", r"\1\2", s)  # remove commas between digits
         s = re.sub(r"\.([^0-9]|$)", r" \1", s)  # remove periods not followed by numbers
-        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep some symbols for numerics
+        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep numeric symbols
 
         s = self.standardize_numbers(s)
         s = self.standardize_spellings(s)
@@ -538,6 +545,6 @@ class EnglishTextNormalizer:
         s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
         s = re.sub(r"([^0-9])%", r"\1 ", s)
 
-        s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space
+        s = re.sub(r"\s+", " ", s)  # replace any successive whitespaces with a space
 
         return s

+ 34 - 16
whisper/timing.py

@@ -1,7 +1,7 @@
 import subprocess
 import warnings
 from dataclasses import dataclass
-from typing import List, TYPE_CHECKING
+from typing import TYPE_CHECKING, List
 
 import numba
 import numpy as np
@@ -26,13 +26,16 @@ def median_filter(x: torch.Tensor, filter_width: int):
         # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
         x = x[None, None, :]
 
-    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
+    assert (
+        filter_width > 0 and filter_width % 2 == 1
+    ), "`filter_width` should be an odd number"
 
     result = None
     x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
     if x.is_cuda:
         try:
             from .triton_ops import median_filter_cuda
+
             result = median_filter_cuda(x, filter_width)
         except (RuntimeError, subprocess.CalledProcessError):
             warnings.warn(
@@ -49,6 +52,7 @@ def median_filter(x: torch.Tensor, filter_width: int):
 
     return result
 
+
 @numba.jit
 def backtrace(trace: np.ndarray):
     i = trace.shape[0] - 1
@@ -106,7 +110,9 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
     M, N = x.shape
     assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
 
-    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    x_skew = (
+        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    )
     x_skew = x_skew.T.contiguous()
     cost = torch.ones(N + M + 2, M + 2) * np.inf
     cost[0, 0] = 0
@@ -122,10 +128,12 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
         trace.stride(0),
         N,
         M,
-        BLOCK_SIZE=BLOCK_SIZE
+        BLOCK_SIZE=BLOCK_SIZE,
     )
 
-    trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1]
+    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
+        :, : N + 1
+    ]
     return backtrace(trace.cpu().numpy())
 
 
@@ -181,8 +189,10 @@ def find_alignment(
 
     with torch.no_grad():
         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
-        token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
-        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
+        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
+        token_probs = sampled_logits.softmax(dim=-1)
+        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
+        text_token_probs = text_token_probs.tolist()
 
     for hook in hooks:
         hook.remove()
@@ -196,7 +206,7 @@ def find_alignment(
     weights = median_filter(weights, medfilt_width)
 
     matrix = weights.mean(axis=0)
-    matrix = matrix[len(tokenizer.sot_sequence):-1]
+    matrix = matrix[len(tokenizer.sot_sequence) : -1]
     text_indices, time_indices = dtw(-matrix)
 
     words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
@@ -207,7 +217,8 @@ def find_alignment(
     start_times = jump_times[word_boundaries[:-1]]
     end_times = jump_times[word_boundaries[1:]]
     word_probabilities = [
-        np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+        np.mean(text_token_probs[i:j])
+        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]
 
     # hack: ensure the first and second word is not longer than twice the median word duration.
@@ -218,7 +229,8 @@ def find_alignment(
         median_duration = np.median(word_durations)
         max_duration = median_duration * 2
         if len(word_durations) >= 2 and word_durations[1] > max_duration:
-            end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
+            boundary = max(end_times[2] / 2, end_times[2] - max_duration)
+            end_times[0] = start_times[1] = boundary
         if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
             start_times[0] = max(0, end_times[0] - max_duration)
 
@@ -271,19 +283,20 @@ def add_word_timestamps(
     tokenizer: Tokenizer,
     mel: torch.Tensor,
     num_frames: int,
-    prepend_punctuations: str = "\"\'“¿([{-",
-    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
-    **hyperparams,
+    prepend_punctuations: str = "\"'“¿([{-",
+    append_punctuations: str = "\"'.。,,!!??::”)]}、",
+    **kwargs,
 ):
     if len(segments) == 0:
         return
 
     text_tokens = [t for segment in segments for t in segment["tokens"]]
-    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
+    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
+    segment_lengths = [len(s["tokens"]) for s in segments]
+    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
 
     for segment in segments:
         segment["words"] = []
@@ -295,7 +308,12 @@ def add_word_timestamps(
             start = round(time_offset + timing.start, 2)
             end = round(time_offset + timing.end, 2)
             segment["words"].append(
-                dict(word=timing.word, start=start, end=end, probability=timing.probability)
+                dict(
+                    word=timing.word,
+                    start=start,
+                    end=end,
+                    probability=timing.probability,
+                )
             )
 
     for segment in segments:

+ 19 - 9
whisper/tokenizer.py

@@ -1,7 +1,7 @@
 import os
 import string
 from dataclasses import dataclass
-from functools import lru_cache, cached_property
+from functools import cached_property, lru_cache
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -138,7 +138,9 @@ class Tokenizer:
     def encode(self, text, **kwargs):
         return self.tokenizer.encode(text, **kwargs)
 
-    def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
+    def decode(
+        self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
+    ):
         return self.tokenizer.decode(token_ids, **kwargs)
 
     def decode_with_timestamps(self, tokens) -> str:
@@ -154,8 +156,9 @@ class Tokenizer:
                 outputs.append([])
             else:
                 outputs[-1].append(token)
-        outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
-        return "".join(outputs)
+        return "".join(
+            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
+        )
 
     @cached_property
     def eot(self) -> int:
@@ -197,7 +200,7 @@ class Tokenizer:
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
-            raise ValueError(f"This tokenizer does not have language token configured")
+            raise ValueError("This tokenizer does not have language token configured")
 
         additional_tokens = dict(
             zip(
@@ -242,8 +245,10 @@ class Tokenizer:
 
         keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
         """
-        symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
-        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+        symbols += (
+            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        )
 
         # symbols that may be a single token or multiple tokens depending on the tokenizer.
         # In case they're multiple tokens, suppress the first token, which is safe because:
@@ -255,7 +260,10 @@ class Tokenizer:
         # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
         result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
         for symbol in symbols + list(miscellaneous):
-            for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
+            for tokens in [
+                self.tokenizer.encode(symbol),
+                self.tokenizer.encode(" " + symbol),
+            ]:
                 if len(tokens) == 1 or symbol in miscellaneous:
                     result.add(tokens[0])
 
@@ -367,4 +375,6 @@ def get_tokenizer(
     if task is not None:
         sot_sequence.append(transcribe if task == "transcribe" else translate)
 
-    return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
+    return Tokenizer(
+        tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
+    )

+ 104 - 39
whisper/transcribe.py

@@ -1,17 +1,32 @@
 import argparse
 import os
 import warnings
-from typing import Optional, Tuple, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import tqdm
 
-from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim
+from .audio import (
+    FRAMES_PER_SECOND,
+    HOP_LENGTH,
+    N_FRAMES,
+    SAMPLE_RATE,
+    log_mel_spectrogram,
+    pad_or_trim,
+)
 from .decoding import DecodingOptions, DecodingResult
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
+from .utils import (
+    exact_div,
+    format_timestamp,
+    get_writer,
+    make_safe,
+    optional_float,
+    optional_int,
+    str2bool,
+)
 
 if TYPE_CHECKING:
     from .model import Whisper
@@ -29,8 +44,8 @@ def transcribe(
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
     word_timestamps: bool = False,
-    prepend_punctuations: str = "\"\'“¿([{-",
-    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
+    prepend_punctuations: str = "\"'“¿([{-",
+    append_punctuations: str = "\"'.。,,!!??::”)]}、",
     **decode_options,
 ):
     """
@@ -108,12 +123,16 @@ def transcribe(
             decode_options["language"] = "en"
         else:
             if verbose:
-                print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
+                print(
+                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+                )
             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
             _, probs = model.detect_language(mel_segment)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
-                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
+                print(
+                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+                )
 
     language: str = decode_options["language"]
     task: str = decode_options.get("task", "transcribe")
@@ -123,7 +142,9 @@ def transcribe(
         warnings.warn("Word-level timestamps on translations may not be reliable.")
 
     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
-        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
+        temperatures = (
+            [temperature] if isinstance(temperature, (int, float)) else temperature
+        )
         decode_result = None
 
         for t in temperatures:
@@ -140,9 +161,15 @@ def transcribe(
             decode_result = model.decode(segment, options)
 
             needs_fallback = False
-            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+            if (
+                compression_ratio_threshold is not None
+                and decode_result.compression_ratio > compression_ratio_threshold
+            ):
                 needs_fallback = True  # too repetitive
-            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+            if (
+                logprob_threshold is not None
+                and decode_result.avg_logprob < logprob_threshold
+            ):
                 needs_fallback = True  # average log probability is too low
 
             if not needs_fallback:
@@ -186,7 +213,9 @@ def transcribe(
 
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
     num_frames = mel.shape[-1]
-    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
+    with tqdm.tqdm(
+        total=num_frames, unit="frames", disable=verbose is not False
+    ) as pbar:
         while seek < num_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
             mel_segment = mel[:, seek:]
@@ -201,7 +230,10 @@ def transcribe(
             if no_speech_threshold is not None:
                 # no voice activity check
                 should_skip = result.no_speech_prob > no_speech_threshold
-                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                if (
+                    logprob_threshold is not None
+                    and result.avg_logprob > logprob_threshold
+                ):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
 
@@ -214,22 +246,35 @@ def transcribe(
             current_tokens = []
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
-            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
+                0
+            ].add_(1)
+            if (
+                len(consecutive) > 0
+            ):  # if the output contains two consecutive timestamp tokens
+                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
+                    False,
+                    True,
+                ]:
                     consecutive = consecutive.tolist() + [len(tokens)]
 
                 last_slice = 0
                 for current_slice in consecutive:
                     sliced_tokens = tokens[last_slice:current_slice]
-                    start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
-                    end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                    current_segments.append(new_segment(
-                        start=time_offset + start_timestamp_pos * time_precision,
-                        end=time_offset + end_timestamp_pos * time_precision,
-                        tokens=sliced_tokens,
-                        result=result,
-                    ))
+                    start_timestamp_pos = (
+                        sliced_tokens[0].item() - tokenizer.timestamp_begin
+                    )
+                    end_timestamp_pos = (
+                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                    )
+                    current_segments.append(
+                        new_segment(
+                            start=time_offset + start_timestamp_pos * time_precision,
+                            end=time_offset + end_timestamp_pos * time_precision,
+                            tokens=sliced_tokens,
+                            result=result,
+                        )
+                    )
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
 
@@ -238,23 +283,32 @@ def transcribe(
                     seek += segment_size
                 else:
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
-                    last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                    last_timestamp_pos = (
+                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                    )
                     seek += last_timestamp_pos * input_stride
                 all_tokens.extend(tokens[: last_slice + 1].tolist())
             else:
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-                if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
+                if (
+                    len(timestamps) > 0
+                    and timestamps[-1].item() != tokenizer.timestamp_begin
+                ):
                     # no consecutive timestamps but it has a timestamp; use the last one.
-                    last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
+                    last_timestamp_pos = (
+                        timestamps[-1].item() - tokenizer.timestamp_begin
+                    )
                     duration = last_timestamp_pos * time_precision
 
-                current_segments.append(new_segment(
-                    start=time_offset,
-                    end=time_offset + duration,
-                    tokens=tokens,
-                    result=result,
-                ))
+                current_segments.append(
+                    new_segment(
+                        start=time_offset,
+                        end=time_offset + duration,
+                        tokens=tokens,
+                        result=result,
+                    )
+                )
                 current_tokens.append(tokens.tolist())
                 seek += segment_size
 
@@ -272,9 +326,13 @@ def transcribe(
                     prepend_punctuations=prepend_punctuations,
                     append_punctuations=append_punctuations,
                 )
-                word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
+                word_end_timestamps = [
+                    w["end"] for s in current_segments for w in s["words"]
+                ]
                 if len(consecutive) > 0 and len(word_end_timestamps) > 0:
-                    seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND)
+                    seek_shift = round(
+                        (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
+                    )
                     if seek_shift > 0:
                         seek = previous_seek + seek_shift
 
@@ -293,21 +351,24 @@ def transcribe(
                     current_tokens[i] = []
 
             all_segments.extend(current_segments)
-            all_tokens.extend([token for segment in current_tokens for token in segment])
+            all_tokens.extend(
+                [token for segment in current_tokens for token in segment]
+            )
 
             # update progress bar
             pbar.update(min(num_frames, seek) - previous_seek)
 
     return dict(
-        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
         segments=all_segments,
-        language=language
+        language=language,
     )
 
 
 def cli():
     from . import available_models
 
+    # fmt: off
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
@@ -339,6 +400,7 @@ def cli():
     parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
     parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+    # fmt: on
 
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
@@ -350,7 +412,9 @@ def cli():
 
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
         if args["language"] is not None:
-            warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
+            warnings.warn(
+                f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
+            )
         args["language"] = "en"
 
     temperature = args.pop("temperature")
@@ -363,6 +427,7 @@ def cli():
         torch.set_num_threads(threads)
 
     from . import load_model
+
     model = load_model(model_name, device=device, download_root=model_dir)
 
     writer = get_writer(output_format, output_dir)
@@ -371,5 +436,5 @@ def cli():
         writer(result, audio_path)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     cli()

+ 41 - 24
whisper/triton_ops.py

@@ -1,8 +1,7 @@
-import math
+from functools import lru_cache
 
 import numpy as np
 import torch
-from functools import lru_cache
 
 try:
     import triton
@@ -12,7 +11,9 @@ except ImportError:
 
 
 @triton.jit
-def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
+def dtw_kernel(
+    cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
+):
     offsets = tl.arange(0, BLOCK_SIZE)
     mask = offsets < M
 
@@ -42,37 +43,53 @@ def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_
 @lru_cache(maxsize=None)
 def median_kernel(filter_width: int):
     @triton.jit
-    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == filter_width
+    def kernel(
+        y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
+    ):  # x.shape[-1] == filter_width
         row_idx = tl.program_id(0)
         offsets = tl.arange(0, BLOCK_SIZE)
         mask = offsets < y_stride
 
-        x_ptr = x + row_idx * x_stride
+        x_ptr = x + row_idx * x_stride  # noqa: F841
         y_ptr = y + row_idx * y_stride
 
-        LOAD_ALL_ROWS_HERE
+        LOAD_ALL_ROWS_HERE  # noqa: F821
 
-        BUBBLESORT_HERE
+        BUBBLESORT_HERE  # noqa: F821
 
-        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)
+        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821
 
     kernel = triton.JITFunction(kernel.fn)
-    kernel.src = kernel.src.replace("    LOAD_ALL_ROWS_HERE", "\n".join([
-        f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
-        for i in range(filter_width)
-    ]))
-    kernel.src = kernel.src.replace("    BUBBLESORT_HERE", "\n\n".join([
-        "\n\n".join([
-            "\n".join([
-                f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
-                f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
-                f"    row{j} = smaller",
-                f"    row{j + 1} = larger",
-            ])
-            for j in range(filter_width - i - 1)
-        ])
-        for i in range(filter_width // 2 + 1)
-    ]))
+    kernel.src = kernel.src.replace(
+        "    LOAD_ALL_ROWS_HERE",
+        "\n".join(
+            [
+                f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
+                for i in range(filter_width)
+            ]
+        ),
+    )
+    kernel.src = kernel.src.replace(
+        "    BUBBLESORT_HERE",
+        "\n\n".join(
+            [
+                "\n\n".join(
+                    [
+                        "\n".join(
+                            [
+                                f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
+                                f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
+                                f"    row{j} = smaller",
+                                f"    row{j + 1} = larger",
+                            ]
+                        )
+                        for j in range(filter_width - i - 1)
+                    ]
+                )
+                for i in range(filter_width // 2 + 1)
+            ]
+        ),
+    )
     kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
 
     return kernel

+ 24 - 12
whisper/utils.py

@@ -7,11 +7,14 @@ from typing import Callable, TextIO
 system_encoding = sys.getdefaultencoding()
 
 if system_encoding != "utf-8":
+
     def make_safe(string):
         # replaces any character not representable using the system default encoding with an '?',
         # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
         return string.encode(system_encoding, errors="replace").decode(system_encoding)
+
 else:
+
     def make_safe(string):
         # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
         return string
@@ -43,7 +46,9 @@ def compression_ratio(text) -> float:
     return len(text_bytes) / len(zlib.compress(text_bytes))
 
 
-def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
+def format_timestamp(
+    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+):
     assert seconds >= 0, "non-negative timestamp expected"
     milliseconds = round(seconds * 1000.0)
 
@@ -57,7 +62,9 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
     milliseconds -= seconds * 1_000
 
     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
-    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    return (
+        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    )
 
 
 class ResultWriter:
@@ -68,7 +75,9 @@ class ResultWriter:
 
     def __call__(self, result: dict, audio_path: str):
         audio_basename = os.path.basename(audio_path)
-        output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
+        output_path = os.path.join(
+            self.output_dir, audio_basename + "." + self.extension
+        )
 
         with open(output_path, "w", encoding="utf-8") as f:
             self.write_result(result, file=f)
@@ -82,7 +91,7 @@ class WriteTXT(ResultWriter):
 
     def write_result(self, result: dict, file: TextIO):
         for segment in result["segments"]:
-            print(segment['text'].strip(), file=file, flush=True)
+            print(segment["text"].strip(), file=file, flush=True)
 
 
 class SubtitlesWriter(ResultWriter):
@@ -93,7 +102,7 @@ class SubtitlesWriter(ResultWriter):
         for segment in result["segments"]:
             segment_start = self.format_timestamp(segment["start"])
             segment_end = self.format_timestamp(segment["end"])
-            segment_text = segment['text'].strip().replace('-->', '->')
+            segment_text = segment["text"].strip().replace("-->", "->")
 
             if word_timings := segment.get("words", None):
                 all_words = [timing["word"] for timing in word_timings]
@@ -106,7 +115,10 @@ class SubtitlesWriter(ResultWriter):
                         yield last, start, segment_text
 
                     yield start, end, "".join(
-                        [f"<u>{word}</u>" if j == i else word for j, word in enumerate(all_words)]
+                        [
+                            f"<u>{word}</u>" if j == i else word
+                            for j, word in enumerate(all_words)
+                        ]
                     )
                     last = end
 
@@ -126,7 +138,7 @@ class SubtitlesWriter(ResultWriter):
 class WriteVTT(SubtitlesWriter):
     extension: str = "vtt"
     always_include_hours: bool = False
-    decimal_marker: str = '.'
+    decimal_marker: str = "."
 
     def write_result(self, result: dict, file: TextIO):
         print("WEBVTT\n", file=file)
@@ -137,7 +149,7 @@ class WriteVTT(SubtitlesWriter):
 class WriteSRT(SubtitlesWriter):
     extension: str = "srt"
     always_include_hours: bool = True
-    decimal_marker: str = ','
+    decimal_marker: str = ","
 
     def write_result(self, result: dict, file: TextIO):
         for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
@@ -153,14 +165,15 @@ class WriteTSV(ResultWriter):
     an environment setting a language encoding that causes the decimal in a floating point number
     to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
     """
+
     extension: str = "tsv"
 
     def write_result(self, result: dict, file: TextIO):
         print("start", "end", "text", sep="\t", file=file)
         for segment in result["segments"]:
-            print(round(1000 * segment['start']), file=file, end="\t")
-            print(round(1000 * segment['end']), file=file, end="\t")
-            print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
+            print(round(1000 * segment["start"]), file=file, end="\t")
+            print(round(1000 * segment["end"]), file=file, end="\t")
+            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
 
 
 class WriteJSON(ResultWriter):
@@ -189,4 +202,3 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
         return write_all
 
     return writers[output_format](output_dir)
-