
apply formatting with `black` (#1038)

* applying black (with the default 88-column limit)

* add flake8

* add isort

* fix isort
Jong Wook Kim committed 2 years ago · commit b80bcf610d

+ 4 - 0
.flake8

@@ -0,0 +1,4 @@
+[flake8]
+per-file-ignores =
+    */__init__.py: F401
+
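
For context: F401 is flake8's "imported but unused" check, and package `__init__.py` files trip it deliberately when they import names only to re-export them, as `whisper/__init__.py` does further down in this diff. A minimal sketch of the pattern the per-file ignore protects (package name hypothetical):

```python
# mypkg/__init__.py -- imports that only define the package's public API;
# without the per-file ignore above, flake8 would flag each as F401
from .audio import load_audio
from .transcribe import transcribe
```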

+ 3 - 0
.github/workflows/test.yml

@@ -22,4 +22,7 @@ jobs:
      - uses: actions/checkout@v2
      - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
      - run: pip install .["dev"]
+      - run: black --check --diff -t py38 --include '(\.pyi?)$' .
+      - run: isort --check --diff .
+      - run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
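
The three new checks can be run locally before pushing; the commands below are copied verbatim from the workflow steps above, with `pip install .["dev"]` pulling in the tools via the `extras_require` change to setup.py further down:

```bash
pip install .["dev"]
black --check --diff -t py38 --include '(\.pyi?)$' .
isort --check --diff .
flake8 --ignore E203,W503,W504,E501,E731,E741 .
```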

+ 8 - 0
pyproject.toml

@@ -0,0 +1,8 @@
+[tool.black]
+
+[tool.isort]
+profile = "black"
+include_trailing_comma = true
+line_length = 88
+multi_line_output = 3
+
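
Note that `profile = "black"` already implies the other three settings (`multi_line_output = 3` is isort's "vertical hanging indent" mode, plus the trailing comma and black's 88-column line length), so spelling them out is redundant but harmless. The resulting import style is visible throughout this diff, e.g. in tests/test_normalizer.py:

```python
# before: a single import line over the 88-column limit
from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer

# after: wrapped as a vertical hanging indent with a trailing comma
from whisper.normalizers.english import (
    EnglishNumberNormalizer,
    EnglishSpellingNormalizer,
)
```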

+ 8 - 4
setup.py

@@ -2,7 +2,7 @@ import os
 import sys

 import pkg_resources
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup


 def read_version(fname="whisper/version.py"):
@@ -16,7 +16,10 @@ if sys.platform.startswith("linux"):
     try:
         import re
         import subprocess
-        version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
+
+        version_line = (
+            subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
+        )
         major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
         if (int(major), int(minor)) < (11, 4):
             # the last version supporting CUDA < 11.4
@@ -38,7 +41,8 @@ setup(
     url="https://github.com/openai/whisper",
     license="MIT",
     packages=find_packages(exclude=["tests*"]),
-    install_requires=requirements + [
+    install_requires=requirements
+    + [
         str(r)
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
@@ -48,5 +52,5 @@ setup(
         "console_scripts": ["whisper=whisper.transcribe:cli"],
     },
     include_package_data=True,
-    extras_require={"dev": ["pytest", "scipy"]},
+    extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
 )

+ 1 - 1
tests/test_audio.py

@@ -2,7 +2,7 @@ import os.path

 import numpy as np

-from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
+from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram


 def test_audio():

+ 4 - 1
tests/test_normalizer.py

@@ -1,7 +1,10 @@
 import pytest

 from whisper.normalizers import EnglishTextNormalizer
-from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
+from whisper.normalizers.english import (
+    EnglishNumberNormalizer,
+    EnglishSpellingNormalizer,
+)


 @pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])

+ 15 - 6
tests/test_timing.py

@@ -1,16 +1,21 @@
-import pytest
 import numpy as np
+import pytest
 import scipy.ndimage
 import torch

 from whisper.timing import dtw_cpu, dtw_cuda, median_filter

-
 sizes = [
-    (10, 20), (32, 16), (123, 1500), (234, 189),
+    (10, 20),
+    (32, 16),
+    (123, 1500),
+    (234, 189),
 ]
 shapes = [
-    (10,), (1, 15),  (4, 5, 345), (6, 12, 240, 512),
+    (10,),
+    (1, 15),
+    (4, 5, 345),
+    (6, 12, 240, 512),
 ]


@@ -68,8 +73,12 @@ def test_median_filter(shape):

         # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
         pad_width = filter_width // 2
-        padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
-        scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
+        padded_x = np.pad(
+            x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
+        )
+        scipy_filtered = scipy.ndimage.median_filter(
+            padded_x, [1] * (x.ndim - 1) + [filter_width]
+        )
         scipy_filtered = scipy_filtered[..., pad_width:-pad_width]

         assert np.allclose(filtered, scipy_filtered)

+ 3 - 1
tests/test_transcribe.py

@@ -13,7 +13,9 @@ def test_transcribe(model_name: str):
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

     language = "en" if model_name.endswith(".en") else None
-    result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
+    result = model.transcribe(
+        audio_path, language=language, temperature=0.0, word_timestamps=True
+    )
     assert result["language"] == "en"

     transcription = result["text"].lower()

+ 30 - 20
whisper/__init__.py

@@ -10,11 +10,10 @@ from tqdm import tqdm

 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
-from .model import Whisper, ModelDimensions
+from .model import ModelDimensions, Whisper
 from .transcribe import transcribe
 from .version import __version__

-
 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
     "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
@@ -41,12 +40,11 @@ _ALIGNMENT_HEADS = {
     "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
     "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
-    "large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
-    "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
+    "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
 }


-
 def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)

@@ -62,10 +60,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
             return model_bytes if in_memory else download_target
         else:
-            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+            warnings.warn(
+                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
+            )

     with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as loop:
             while True:
                 buffer = source.read(8192)
                 if not buffer:
@@ -76,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:

     model_bytes = open(download_target, "rb").read()
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
-        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
+        )

     return model_bytes if in_memory else download_target

@@ -86,7 +94,12 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())


-def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
+def load_model(
+    name: str,
+    device: Optional[Union[str, torch.device]] = None,
+    download_root: str = None,
+    in_memory: bool = False,
+) -> Whisper:
     """
     Load a Whisper ASR model

@@ -111,15 +124,8 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     if download_root is None:
-        download_root = os.path.join(
-            os.getenv(
-                "XDG_CACHE_HOME",
-                os.path.join(
-                    os.path.expanduser("~"), ".cache"
-                )
-            ),
-            "whisper"
-        )
+        default = os.path.join(os.path.expanduser("~"), ".cache")
+        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

     if name in _MODELS:
         checkpoint_file = _download(_MODELS[name], download_root, in_memory)
@@ -128,9 +134,13 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
         checkpoint_file = open(name, "rb").read() if in_memory else name
         alignment_heads = None
     else:
-        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+        raise RuntimeError(
+            f"Model {name} not found; available models = {available_models()}"
+        )

-    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
+    with (
+        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
+    ) as fp:
         checkpoint = torch.load(fp, map_location=device)
     del checkpoint_file


+ 0 - 1
whisper/__main__.py

@@ -1,4 +1,3 @@
 from .transcribe import cli

-
 cli()

+ 14 - 6
whisper/audio.py

@@ -16,11 +16,13 @@ N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
-N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input
+N_FRAMES = exact_div(
+    N_SAMPLES, HOP_LENGTH
+)  # 3000: number of frames in a mel spectrogram input

 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
-FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 100 mel frames in 1s (10ms each)
-TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 audio tokens in 1s (20ms each)
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


 def load_audio(file: str, sr: int = SAMPLE_RATE):
@@ -59,7 +61,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     """
     if torch.is_tensor(array):
         if array.shape[axis] > length:
-            array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
+            array = array.index_select(
+                dim=axis, index=torch.arange(length, device=array.device)
+            )

         if array.shape[axis] < length:
             pad_widths = [(0, 0)] * array.ndim
@@ -89,11 +93,15 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
         )
     """
     assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
-    with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
+    with np.load(
+        os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+    ) as f:
         return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


-def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
+def log_mel_spectrogram(
+    audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
+):
     """
     Compute the log-Mel spectrogram of


+ 145 - 53
whisper/decoding.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -16,7 +16,9 @@ if TYPE_CHECKING:


 @torch.no_grad()
-def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
+def detect_language(
+    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
+) -> Tuple[Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
     of the most probable language tokens and the probability distribution over all language tokens.
@@ -31,8 +33,13 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
     """
     if tokenizer is None:
         tokenizer = get_tokenizer(model.is_multilingual)
-    if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
-        raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
+    if (
+        tokenizer.language is None
+        or tokenizer.language_token not in tokenizer.sot_sequence
+    ):
+        raise ValueError(
+            "This model doesn't have language tokens so it can't perform lang id"
+        )

     single = mel.ndim == 2
     if single:
@@ -70,31 +77,36 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)

 @dataclass(frozen=True)
 class DecodingOptions:
-    task: str = "transcribe"  # whether to perform X->X "transcribe" or X->English "translate"
-    language: Optional[str] = None  # language that the audio is in; uses detected language if None
+    # whether to perform X->X "transcribe" or X->English "translate"
+    task: str = "transcribe"
+
+    # language that the audio is in; uses detected language if None
+    language: Optional[str] = None

     # sampling-related options
     temperature: float = 0.0
     sample_len: Optional[int] = None  # maximum number of tokens to sample
-    best_of: Optional[int] = None     # number of independent samples to collect, when t > 0
-    beam_size: Optional[int] = None   # number of beams in beam search, when t == 0
-    patience: Optional[float] = None  # patience in beam search (https://arxiv.org/abs/2204.05424)
+    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
+    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
+    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)

-    # options for ranking generations (either beams or best-of-N samples)
-    length_penalty: Optional[float] = None   # "alpha" in Google NMT, None defaults to length norm
+    # "alpha" in Google NMT, or None for length norm, when ranking generations
+    # to select which to return among the beams or best-of-N samples
+    length_penalty: Optional[float] = None

-    # prompt, prefix, and token suppression
-    prompt: Optional[Union[str, List[int]]] = None   # text or tokens for the previous context
-    prefix: Optional[Union[str, List[int]]] = None   # text or tokens to prefix the current context
-    suppress_blank: bool = True                      # this will suppress blank outputs
+    # text or tokens to feed as the prompt or the prefix; for more info:
+    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
+    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
+    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context

     # list of tokens ids (or comma-separated token ids) to suppress
     # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
     suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
+    suppress_blank: bool = True  # this will suppress blank outputs

     # timestamp sampling options
-    without_timestamps: bool = False              # use <|notimestamps|> to sample text tokens only
-    max_initial_timestamp: Optional[float] = 1.0  # the initial timestamp cannot be later than this
+    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
+    max_initial_timestamp: Optional[float] = 1.0

     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
@@ -158,7 +170,9 @@ class PyTorchInference(Inference):


 class SequenceRanker:
-    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
+    def rank(
+        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
+    ) -> List[int]:
         """
         Given a list of groups of samples and their cumulative log probabilities,
         return the indices of the samples in each group to select as the final result
@@ -196,7 +210,9 @@ class TokenDecoder:
     def reset(self):
         """Initialize any stateful variables for decoding a new sequence"""

-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         """Specify how to select the next token, based on the current trace and logits

         Parameters
@@ -251,7 +267,9 @@ class GreedyDecoder(TokenDecoder):
         self.temperature = temperature
         self.eot = eot

-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
@@ -274,7 +292,13 @@ class GreedyDecoder(TokenDecoder):


 class BeamSearchDecoder(TokenDecoder):
-    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
+    def __init__(
+        self,
+        beam_size: int,
+        eot: int,
+        inference: Inference,
+        patience: Optional[float] = None,
+    ):
         self.beam_size = beam_size
         self.eot = eot
         self.inference = inference
@@ -282,12 +306,16 @@ class BeamSearchDecoder(TokenDecoder):
         self.max_candidates: int = round(beam_size * self.patience)
         self.finished_sequences = None

-        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
+        assert (
+            self.max_candidates > 0
+        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

     def reset(self):
         self.finished_sequences = None

-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         if tokens.shape[0] % self.beam_size != 0:
             raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

@@ -331,7 +359,9 @@ class BeamSearchDecoder(TokenDecoder):

         # add newly finished sequences to self.finished_sequences
         assert len(self.finished_sequences) == len(finished_sequences)
-        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
+        for previously_finished, newly_finished in zip(
+            self.finished_sequences, finished_sequences
+        ):
             for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                 if len(previously_finished) >= self.max_candidates:
                     break  # the candidate list is full
@@ -339,7 +369,8 @@ class BeamSearchDecoder(TokenDecoder):

         # mark as completed if all audio has enough number of samples
         completed = all(
-            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
+            len(sequences) >= self.max_candidates
+            for sequences in self.finished_sequences
         )
         return tokens, completed

@@ -347,7 +378,9 @@ class BeamSearchDecoder(TokenDecoder):
         # collect all finished sequences, including patience, and add unfinished ones if not enough
         sum_logprobs = sum_logprobs.cpu()
         for i, sequences in enumerate(self.finished_sequences):
-            if len(sequences) < self.beam_size:  # when not enough sequences are finished
+            if (
+                len(sequences) < self.beam_size
+            ):  # when not enough sequences are finished
                 for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                     sequence = preceding_tokens[i, j].tolist() + [self.eot]
                     sequences[tuple(sequence)] = sum_logprobs[i][j].item()
@@ -355,7 +388,8 @@ class BeamSearchDecoder(TokenDecoder):
                         break

         tokens: List[List[Tensor]] = [
-            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
+            [torch.tensor(seq) for seq in sequences.keys()]
+            for sequences in self.finished_sequences
         ]
         sum_logprobs: List[List[float]] = [
             list(sequences.values()) for sequences in self.finished_sequences
@@ -399,7 +433,10 @@ class SuppressTokens(LogitFilter):

 class ApplyTimestampRules(LogitFilter):
     def __init__(
-        self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
+        self,
+        tokenizer: Tokenizer,
+        sample_begin: int,
+        max_initial_timestamp_index: Optional[int],
     ):
         self.tokenizer = tokenizer
         self.sample_begin = sample_begin
@@ -414,8 +451,12 @@ class ApplyTimestampRules(LogitFilter):
         for k in range(tokens.shape[0]):
             sampled_tokens = tokens[k, self.sample_begin :]
             seq = [t for t in sampled_tokens.tolist()]
-            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
-            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+            last_was_timestamp = (
+                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
+            )
+            penultimate_was_timestamp = (
+                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+            )

             if last_was_timestamp:
                 if penultimate_was_timestamp:  # has to be non-timestamp
@@ -423,7 +464,9 @@ class ApplyTimestampRules(LogitFilter):
                 else:  # cannot be normal text tokens
                     logits[k, : self.tokenizer.eot] = -np.inf

-            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
+            timestamps = sampled_tokens[
+                sampled_tokens.ge(self.tokenizer.timestamp_begin)
+            ]
             if timestamps.numel() > 0:
                 # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                 logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
@@ -434,13 +477,17 @@ class ApplyTimestampRules(LogitFilter):

             # apply the `max_initial_timestamp` option
             if self.max_initial_timestamp_index is not None:
-                last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                last_allowed = (
+                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                )
                 logits[:, last_allowed + 1 :] = -np.inf

         # if sum of probability over timestamps is above any other token, sample timestamp
         logprobs = F.log_softmax(logits.float(), dim=-1)
         for k in range(tokens.shape[0]):
-            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
+            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
+                dim=-1
+            )
             max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
             if timestamp_logprob > max_text_token_logprob:
                 logits[k, : self.tokenizer.timestamp_begin] = -np.inf
@@ -456,7 +503,9 @@ class DecodingTask:
         self.model = model

         language = options.language or "en"
-        tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
+        tokenizer = get_tokenizer(
+            model.is_multilingual, language=language, task=options.task
+        )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)

@@ -496,9 +545,13 @@ class DecodingTask:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None
             if options.max_initial_timestamp:
-                max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
+                max_initial_timestamp_index = round(
+                    self.options.max_initial_timestamp / precision
+                )
             self.logit_filters.append(
-                ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
+                ApplyTimestampRules(
+                    tokenizer, self.sample_begin, max_initial_timestamp_index
+                )
             )

     def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
@@ -509,7 +562,9 @@ class DecodingTask:
                 raise ValueError("best_of with greedy sampling (T=0) is not compatible")
         if options.patience is not None and options.beam_size is None:
             raise ValueError("patience requires beam_size to be given")
-        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
+        if options.length_penalty is not None and not (
+            0 <= options.length_penalty <= 1
+        ):
             raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

         return options
@@ -519,7 +574,9 @@ class DecodingTask:

         if prefix := self.options.prefix:
             prefix_tokens = (
-                self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
+                self.tokenizer.encode(" " + prefix.strip())
+                if isinstance(prefix, str)
+                else prefix
             )
             if self.sample_len is not None:
                 max_prefix_len = self.n_ctx // 2 - self.sample_len
@@ -528,9 +585,15 @@ class DecodingTask:

         if prompt := self.options.prompt:
             prompt_tokens = (
-                self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
+                self.tokenizer.encode(" " + prompt.strip())
+                if isinstance(prompt, str)
+                else prompt
+            )
+            tokens = (
+                [self.tokenizer.sot_prev]
+                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
+                + tokens
             )
-            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens

         return tuple(tokens)

@@ -554,7 +617,7 @@ class DecodingTask:
                 self.tokenizer.translate,
                 self.tokenizer.sot,
                 self.tokenizer.sot_prev,
-                self.tokenizer.sot_lm
+                self.tokenizer.sot_lm,
             ]
         )
         if self.tokenizer.no_speech is not None:
@@ -567,14 +630,21 @@ class DecodingTask:
         if self.options.fp16:
             mel = mel.half()

-        if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
+        if mel.shape[-2:] == (
+            self.model.dims.n_audio_ctx,
+            self.model.dims.n_audio_state,
+        ):
             # encoded audio features are given; skip audio encoding
             audio_features = mel
         else:
             audio_features = self.model.encoder(mel)

-        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
-            return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
+        if audio_features.dtype != (
+            torch.float16 if self.options.fp16 else torch.float32
+        ):
+            return TypeError(
+                f"audio_features has an incorrect dtype: {audio_features.dtype}"
+            )

         return audio_features

@@ -583,7 +653,9 @@ class DecodingTask:
         lang_probs = None

         if self.options.language is None or self.options.task == "lang_id":
-            lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
+            lang_tokens, lang_probs = self.model.detect_language(
+                audio_features, self.tokenizer
+            )
             languages = [max(probs, key=probs.get) for probs in lang_probs]
             if self.options.language is None:
                 tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
@@ -600,7 +672,9 @@ class DecodingTask:
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)

-                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
+                if (
+                    i == 0 and self.tokenizer.no_speech is not None
+                ):  # save no_speech_probs
                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                     no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

@@ -634,8 +708,12 @@ class DecodingTask:
         languages, language_probs = self._detect_language(audio_features, tokens)
         if self.options.task == "lang_id":
             return [
-                DecodingResult(audio_features=features, language=language, language_probs=probs)
-                for features, language, probs in zip(audio_features, languages, language_probs)
+                DecodingResult(
+                    audio_features=features, language=language, language_probs=probs
+                )
+                for features, language, probs in zip(
+                    audio_features, languages, language_probs
+                )
             ]

         # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
@@ -656,7 +734,8 @@ class DecodingTask:
         # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
         tokens: List[List[Tensor]] = [
-            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
+            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
+            for s in tokens
         ]

         # select the top-ranked sample in each group
@@ -665,9 +744,18 @@ class DecodingTask:
         texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
-        avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
+        avg_logprobs: List[float] = [
+            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
+        ]

-        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
+        fields = (
+            texts,
+            languages,
+            tokens,
+            audio_features,
+            avg_logprobs,
+            no_speech_probs,
+        )
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

@@ -682,12 +770,16 @@ class DecodingTask:
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
             )
-            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
+                *fields
+            )
         ]


 @torch.no_grad()
-def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
+def decode(
+    model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
+) -> Union[DecodingResult, List[DecodingResult]]:
     """
     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).


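None of the reflowed hunks above change decoding behavior; `DecodingOptions` and `decode()` are still used as in the project README. A quick sanity check (the audio path is a placeholder):

```python
import whisper

model = whisper.load_model("tiny")                 # cached under $XDG_CACHE_HOME/whisper
audio = whisper.pad_or_trim(whisper.load_audio("audio.wav"))  # exactly 30s of samples
mel = whisper.log_mel_spectrogram(audio).to(model.device)     # (80, 3000) mel features
options = whisper.DecodingOptions(language="en", fp16=False)
result = whisper.decode(model, mel, options)
print(result.text)
```
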
+ 52 - 22
whisper/model.py

@@ -1,16 +1,15 @@
 import base64
 import gzip
 from dataclasses import dataclass
-from typing import Dict
-from typing import Iterable, Optional
+from typing import Dict, Iterable, Optional

 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch import Tensor
-from torch import nn
+from torch import Tensor, nn

-from .decoding import detect_language as detect_language_function, decode as decode_function
+from .decoding import decode as decode_function
+from .decoding import detect_language as detect_language_function
 from .transcribe import transcribe as transcribe_function


@@ -36,12 +35,16 @@ class LayerNorm(nn.LayerNorm):
 class Linear(nn.Linear):
     def forward(self, x: Tensor) -> Tensor:
         return F.linear(
-            x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
+            x,
+            self.weight.to(x.dtype),
+            None if self.bias is None else self.bias.to(x.dtype),
         )


 class Conv1d(nn.Conv1d):
-    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
+    def _conv_forward(
+        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
+    ) -> Tensor:
         return super()._conv_forward(
             x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
         )
@@ -87,7 +90,9 @@ class MultiHeadAttention(nn.Module):
         wv, qk = self.qkv_attention(q, k, v, mask)
         return self.out(wv), qk

-    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
+    def qkv_attention(
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+    ):
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
@@ -110,11 +115,15 @@ class ResidualAttentionBlock(nn.Module):
         self.attn = MultiHeadAttention(n_state, n_head)
         self.attn_ln = LayerNorm(n_state)

-        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
+        self.cross_attn = (
+            MultiHeadAttention(n_state, n_head) if cross_attention else None
+        )
         self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

         n_mlp = n_state * 4
-        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
+        self.mlp = nn.Sequential(
+            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+        )
         self.mlp_ln = LayerNorm(n_state)

     def forward(
@@ -132,7 +141,9 @@ class ResidualAttentionBlock(nn.Module):


 class AudioEncoder(nn.Module):
-    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+    def __init__(
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
         super().__init__()
         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
         self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
@@ -163,14 +174,19 @@ class AudioEncoder(nn.Module):


 class TextDecoder(nn.Module):
-    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+    def __init__(
+        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
         super().__init__()

         self.token_embedding = nn.Embedding(n_vocab, n_state)
         self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
-            [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
+            [
+                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+                for _ in range(n_layer)
+            ]
         )
         self.ln = LayerNorm(n_state)

@@ -185,14 +201,19 @@ class TextDecoder(nn.Module):
             the encoded audio features to be attended on
         """
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
+        x = (
+            self.token_embedding(x)
+            + self.positional_embedding[offset : offset + x.shape[-1]]
+        )
         x = x.to(xa.dtype)

         for block in self.blocks:
             x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

         x = self.ln(x)
-        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
+        logits = (
+            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+        ).float()

         return logits

@@ -216,13 +237,19 @@ class Whisper(nn.Module):
             self.dims.n_text_layer,
         )
         # use the last half layers for alignment by default; see `set_alignment_heads()` below
-        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
-        all_heads[self.dims.n_text_layer // 2:] = True
+        all_heads = torch.zeros(
+            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
+        )
+        all_heads[self.dims.n_text_layer // 2 :] = True
         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

     def set_alignment_heads(self, dump: bytes):
-        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
-        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
+        array = np.frombuffer(
+            gzip.decompress(base64.b85decode(dump)), dtype=bool
+        ).copy()
+        mask = torch.from_numpy(array).reshape(
+            self.dims.n_text_layer, self.dims.n_text_head
+        )
         self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

     def embed_audio(self, mel: torch.Tensor):
@@ -231,7 +258,9 @@ class Whisper(nn.Module):
     def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
         return self.decoder(tokens, audio_features)

-    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def forward(
+        self, mel: torch.Tensor, tokens: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
         return self.decoder(tokens, self.encoder(mel))

     @property
@@ -260,8 +289,9 @@ class Whisper(nn.Module):
         hooks = []

         def save_to_cache(module, _, output):
-            if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
-                cache[module] = output  # save as-is, for the first token or cross attention
+            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
+                # save as-is, for the first token or cross attention
+                cache[module] = output
             else:
                 cache[module] = torch.cat([cache[module], output], dim=1).detach()
             return cache[module]

+ 2 - 2
whisper/normalizers/__init__.py

@@ -1,2 +1,2 @@
-from .basic import BasicTextNormalizer
-from .english import EnglishTextNormalizer
+from .basic import BasicTextNormalizer as BasicTextNormalizer
+from .english import EnglishTextNormalizer as EnglishTextNormalizer
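
The redundant `X as X` spelling is the PEP 484 convention for marking a re-export as intentional: type checkers (and newer linters) treat `from m import X as X` as part of the module's public interface rather than an unused import. A minimal sketch with hypothetical names:

```python
# pkg/__init__.py (hypothetical)
from .impl import Thing as Thing  # explicit re-export: Thing is public API
from .impl import helper          # plain import: reads as unused to linters
```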

+ 8 - 3
whisper/normalizers/basic.py

@@ -48,13 +48,16 @@ def remove_symbols(s: str):
     Replace any other markers, symbols, punctuations with a space, keeping diacritics
     """
     return "".join(
-        " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
+        " " if unicodedata.category(c)[0] in "MSP" else c
+        for c in unicodedata.normalize("NFKC", s)
     )


 class BasicTextNormalizer:
     def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
-        self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
+        self.clean = (
+            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
+        )
         self.split_letters = split_letters

     def __call__(self, s: str):
@@ -66,6 +69,8 @@ class BasicTextNormalizer:
         if self.split_letters:
             s = " ".join(regex.findall(r"\X", s, regex.U))

-        s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space
+        s = re.sub(
+            r"\s+", " ", s
+        )  # replace any successive whitespace characters with a space

         return s

+ 14 - 7
whisper/normalizers/english.py

@@ -84,7 +84,8 @@ class EnglishNumberNormalizer:
             name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
         }
         self.tens_ordinal = {
-            name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
+            name.replace("y", "ieth"): (value, "th")
+            for name, value in self.tens.items()
         }
         self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}

@@ -108,7 +109,10 @@ class EnglishNumberNormalizer:
         self.multipliers_ordinal = {
         self.multipliers_ordinal = {
             name + "th": (value, "th") for name, value in self.multipliers.items()
             name + "th": (value, "th") for name, value in self.multipliers.items()
         }
         }
-        self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
+        self.multipliers_suffixed = {
+            **self.multipliers_plural,
+            **self.multipliers_ordinal,
+        }
         self.decimals = {*self.ones, *self.tens, *self.zeros}
         self.decimals = {*self.ones, *self.tens, *self.zeros}
 
 
         self.preceding_prefixers = {
         self.preceding_prefixers = {
@@ -128,7 +132,8 @@ class EnglishNumberNormalizer:
             "cents": "¢",
             "cents": "¢",
         }
         }
         self.prefixes = set(
         self.prefixes = set(
-            list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
+            list(self.preceding_prefixers.values())
+            + list(self.following_prefixers.values())
         )
         )
         self.suffixers = {
         self.suffixers = {
             "per": {"cent": "%"},
             "per": {"cent": "%"},
@@ -218,7 +223,9 @@ class EnglishNumberNormalizer:
                 if value is None:
                 if value is None:
                     value = ones
                     value = ones
                 elif isinstance(value, str) or prev in self.ones:
                 elif isinstance(value, str) or prev in self.ones:
-                    if prev in self.tens and ones < 10:  # replace the last zero with the digit
+                    if (
+                        prev in self.tens and ones < 10
+                    ):  # replace the last zero with the digit
                         assert value[-1] == "0"
                         assert value[-1] == "0"
                         value = value[:-1] + str(ones)
                         value = value[:-1] + str(ones)
                     else:
                     else:
@@ -522,14 +529,14 @@ class EnglishTextNormalizer:
         s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
         s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
         s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
         s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
         s = re.sub(self.ignore_patterns, "", s)
         s = re.sub(self.ignore_patterns, "", s)
-        s = re.sub(r"\s+'", "'", s)  # standardize when there's a space before an apostrophe
+        s = re.sub(r"\s+'", "'", s)  # when there's a space before an apostrophe
 
 
         for pattern, replacement in self.replacers.items():
         for pattern, replacement in self.replacers.items():
             s = re.sub(pattern, replacement, s)
             s = re.sub(pattern, replacement, s)
 
 
         s = re.sub(r"(\d),(\d)", r"\1\2", s)  # remove commas between digits
         s = re.sub(r"(\d),(\d)", r"\1\2", s)  # remove commas between digits
         s = re.sub(r"\.([^0-9]|$)", r" \1", s)  # remove periods not followed by numbers
         s = re.sub(r"\.([^0-9]|$)", r" \1", s)  # remove periods not followed by numbers
-        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep some symbols for numerics
+        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep numeric symbols
 
 
         s = self.standardize_numbers(s)
         s = self.standardize_numbers(s)
         s = self.standardize_spellings(s)
         s = self.standardize_spellings(s)
@@ -538,6 +545,6 @@ class EnglishTextNormalizer:
         s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
         s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
         s = re.sub(r"([^0-9])%", r"\1 ", s)
         s = re.sub(r"([^0-9])%", r"\1 ", s)
 
 
-        s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space
+        s = re.sub(r"\s+", " ", s)  # replace any successive whitespaces with a space
 
 
         return s
         return s

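The dict comprehensions reflowed in this file build suffixed variants of the base number tables; a toy two-entry table shows the naming pattern they produce:

    tens = {"twenty": 20, "thirty": 30}  # abbreviated stand-in for the real table
    tens_plural = {
        name.replace("y", "ies"): (value, "s") for name, value in tens.items()
    }
    tens_ordinal = {
        name.replace("y", "ieth"): (value, "th") for name, value in tens.items()
    }
    print(tens_plural)   # {'twenties': (20, 's'), 'thirties': (30, 's')}
    print(tens_ordinal)  # {'twentieth': (20, 'th'), 'thirtieth': (30, 'th')}
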
+ 34 - 16
whisper/timing.py

@@ -1,7 +1,7 @@
 import subprocess
 import subprocess
 import warnings
 import warnings
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import List, TYPE_CHECKING
+from typing import TYPE_CHECKING, List
 
 
 import numba
 import numba
 import numpy as np
 import numpy as np
@@ -26,13 +26,16 @@ def median_filter(x: torch.Tensor, filter_width: int):
         # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
         # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
         x = x[None, None, :]
         x = x[None, None, :]
 
 
-    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
+    assert (
+        filter_width > 0 and filter_width % 2 == 1
+    ), "`filter_width` should be an odd number"
 
 
     result = None
     result = None
     x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
     x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
     if x.is_cuda:
     if x.is_cuda:
         try:
         try:
             from .triton_ops import median_filter_cuda
             from .triton_ops import median_filter_cuda
+
             result = median_filter_cuda(x, filter_width)
             result = median_filter_cuda(x, filter_width)
         except (RuntimeError, subprocess.CalledProcessError):
         except (RuntimeError, subprocess.CalledProcessError):
             warnings.warn(
             warnings.warn(
@@ -49,6 +52,7 @@ def median_filter(x: torch.Tensor, filter_width: int):
 
 
     return result
     return result
 
 
+
 @numba.jit
 @numba.jit
 def backtrace(trace: np.ndarray):
 def backtrace(trace: np.ndarray):
     i = trace.shape[0] - 1
     i = trace.shape[0] - 1
@@ -106,7 +110,9 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
     M, N = x.shape
     M, N = x.shape
     assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
     assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
 
 
-    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    x_skew = (
+        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    )
     x_skew = x_skew.T.contiguous()
     x_skew = x_skew.T.contiguous()
     cost = torch.ones(N + M + 2, M + 2) * np.inf
     cost = torch.ones(N + M + 2, M + 2) * np.inf
     cost[0, 0] = 0
     cost[0, 0] = 0
@@ -122,10 +128,12 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
         trace.stride(0),
         trace.stride(0),
         N,
         N,
         M,
         M,
-        BLOCK_SIZE=BLOCK_SIZE
+        BLOCK_SIZE=BLOCK_SIZE,
     )
     )
 
 
-    trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1]
+    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
+        :, : N + 1
+    ]
     return backtrace(trace.cpu().numpy())
     return backtrace(trace.cpu().numpy())
 
 
 
 
@@ -181,8 +189,10 @@ def find_alignment(
 
 
     with torch.no_grad():
     with torch.no_grad():
         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
-        token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
-        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
+        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
+        token_probs = sampled_logits.softmax(dim=-1)
+        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
+        text_token_probs = text_token_probs.tolist()
 
 
     for hook in hooks:
     for hook in hooks:
         hook.remove()
         hook.remove()
@@ -196,7 +206,7 @@ def find_alignment(
     weights = median_filter(weights, medfilt_width)
     weights = median_filter(weights, medfilt_width)
 
 
     matrix = weights.mean(axis=0)
     matrix = weights.mean(axis=0)
-    matrix = matrix[len(tokenizer.sot_sequence):-1]
+    matrix = matrix[len(tokenizer.sot_sequence) : -1]
     text_indices, time_indices = dtw(-matrix)
     text_indices, time_indices = dtw(-matrix)
 
 
     words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
     words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
@@ -207,7 +217,8 @@ def find_alignment(
     start_times = jump_times[word_boundaries[:-1]]
     start_times = jump_times[word_boundaries[:-1]]
     end_times = jump_times[word_boundaries[1:]]
     end_times = jump_times[word_boundaries[1:]]
     word_probabilities = [
     word_probabilities = [
-        np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+        np.mean(text_token_probs[i:j])
+        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]
     ]
 
 
     # hack: ensure the first and second words are not longer than twice the median word duration.
     # hack: ensure the first and second words are not longer than twice the median word duration.
@@ -218,7 +229,8 @@ def find_alignment(
         median_duration = np.median(word_durations)
         median_duration = np.median(word_durations)
         max_duration = median_duration * 2
         max_duration = median_duration * 2
         if len(word_durations) >= 2 and word_durations[1] > max_duration:
         if len(word_durations) >= 2 and word_durations[1] > max_duration:
-            end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
+            boundary = max(end_times[2] / 2, end_times[2] - max_duration)
+            end_times[0] = start_times[1] = boundary
         if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
         if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
             start_times[0] = max(0, end_times[0] - max_duration)
             start_times[0] = max(0, end_times[0] - max_duration)
 
 
@@ -271,19 +283,20 @@ def add_word_timestamps(
     tokenizer: Tokenizer,
     tokenizer: Tokenizer,
     mel: torch.Tensor,
     mel: torch.Tensor,
     num_frames: int,
     num_frames: int,
-    prepend_punctuations: str = "\"\'“¿([{-",
-    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
-    **hyperparams,
+    prepend_punctuations: str = "\"'“¿([{-",
+    append_punctuations: str = "\"'.。,,!!??::”)]}、",
+    **kwargs,
 ):
 ):
     if len(segments) == 0:
     if len(segments) == 0:
         return
         return
 
 
     text_tokens = [t for segment in segments for t in segment["tokens"]]
     text_tokens = [t for segment in segments for t in segment["tokens"]]
-    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
+    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
 
     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
+    segment_lengths = [len(s["tokens"]) for s in segments]
+    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
 
 
     for segment in segments:
     for segment in segments:
         segment["words"] = []
         segment["words"] = []
@@ -295,7 +308,12 @@ def add_word_timestamps(
             start = round(time_offset + timing.start, 2)
             start = round(time_offset + timing.start, 2)
             end = round(time_offset + timing.end, 2)
             end = round(time_offset + timing.end, 2)
             segment["words"].append(
             segment["words"].append(
-                dict(word=timing.word, start=start, end=end, probability=timing.probability)
+                dict(
+                    word=timing.word,
+                    start=start,
+                    end=end,
+                    probability=timing.probability,
+                )
             )
             )
 
 
     for segment in segments:
     for segment in segments:

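For intuition, the CPU branch of the reformatted `median_filter` amounts to reflect-padding and a sliding-window median; a pure-PyTorch sketch on a toy signal (skipping the Triton fast path):

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[1.0, 9.0, 2.0, 8.0, 3.0]])  # (rows, time)
    filter_width = 3  # must be a positive odd number, as the assert above enforces
    pad = filter_width // 2
    padded = F.pad(x[None, :], (pad, pad), mode="reflect")[0]  # reflect needs 3D input
    windows = padded.unfold(-1, filter_width, 1)  # (rows, time, filter_width)
    print(windows.median(dim=-1).values)  # tensor([[9., 2., 8., 3., 8.]])
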
+ 19 - 9
whisper/tokenizer.py

@@ -1,7 +1,7 @@
 import os
 import os
 import string
 import string
 from dataclasses import dataclass
 from dataclasses import dataclass
-from functools import lru_cache, cached_property
+from functools import cached_property, lru_cache
 from typing import List, Optional, Tuple, Union
 from typing import List, Optional, Tuple, Union
 
 
 import numpy as np
 import numpy as np
@@ -138,7 +138,9 @@ class Tokenizer:
     def encode(self, text, **kwargs):
     def encode(self, text, **kwargs):
         return self.tokenizer.encode(text, **kwargs)
         return self.tokenizer.encode(text, **kwargs)
 
 
-    def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
+    def decode(
+        self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
+    ):
         return self.tokenizer.decode(token_ids, **kwargs)
         return self.tokenizer.decode(token_ids, **kwargs)
 
 
     def decode_with_timestamps(self, tokens) -> str:
     def decode_with_timestamps(self, tokens) -> str:
@@ -154,8 +156,9 @@ class Tokenizer:
                 outputs.append([])
                 outputs.append([])
             else:
             else:
                 outputs[-1].append(token)
                 outputs[-1].append(token)
-        outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
-        return "".join(outputs)
+        return "".join(
+            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
+        )
 
 
     @cached_property
     @cached_property
     def eot(self) -> int:
     def eot(self) -> int:
@@ -197,7 +200,7 @@ class Tokenizer:
     def language_token(self) -> int:
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
         if self.language is None:
-            raise ValueError(f"This tokenizer does not have language token configured")
+            raise ValueError("This tokenizer does not have language token configured")
 
 
         additional_tokens = dict(
         additional_tokens = dict(
             zip(
             zip(
@@ -242,8 +245,10 @@ class Tokenizer:
 
 
         keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
         keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
         """
         """
-        symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
-        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+        symbols += (
+            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        )
 
 
         # symbols that may be a single token or multiple tokens depending on the tokenizer.
         # symbols that may be a single token or multiple tokens depending on the tokenizer.
         # In case they're multiple tokens, suppress the first token, which is safe because:
         # In case they're multiple tokens, suppress the first token, which is safe because:
@@ -255,7 +260,10 @@ class Tokenizer:
         # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
         # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
         result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
         result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
         for symbol in symbols + list(miscellaneous):
         for symbol in symbols + list(miscellaneous):
-            for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
+            for tokens in [
+                self.tokenizer.encode(symbol),
+                self.tokenizer.encode(" " + symbol),
+            ]:
                 if len(tokens) == 1 or symbol in miscellaneous:
                 if len(tokens) == 1 or symbol in miscellaneous:
                     result.add(tokens[0])
                     result.add(tokens[0])
 
 
@@ -367,4 +375,6 @@ def get_tokenizer(
     if task is not None:
     if task is not None:
         sot_sequence.append(transcribe if task == "transcribe" else translate)
         sot_sequence.append(transcribe if task == "transcribe" else translate)
 
 
-    return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
+    return Tokenizer(
+        tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
+    )

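The rewritten return of `decode_with_timestamps` joins alternating text chunks and timestamp strings; a toy version with the tokenizer stubbed out (the 0.02-second granularity matches Whisper's timestamp tokens; everything else here is illustrative):

    timestamp_begin = 100  # stand-in for tokenizer.timestamp_begin

    def decode_with_timestamps(tokens):
        outputs = [[]]
        for token in tokens:
            if token >= timestamp_begin:
                outputs.append(f"<|{(token - timestamp_begin) * 0.02:.2f}|>")
                outputs.append([])
            else:
                outputs[-1].append(token)
        # stub "decode": join token ids as digits instead of real text
        return "".join(
            s if isinstance(s, str) else "".join(map(str, s)) for s in outputs
        )

    print(decode_with_timestamps([100, 1, 2, 125]))  # <|0.00|>12<|0.50|>
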
+ 104 - 39
whisper/transcribe.py

@@ -1,17 +1,32 @@
 import argparse
 import argparse
 import os
 import os
 import warnings
 import warnings
-from typing import Optional, Tuple, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch
 import tqdm
 import tqdm
 
 
-from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim
+from .audio import (
+    FRAMES_PER_SECOND,
+    HOP_LENGTH,
+    N_FRAMES,
+    SAMPLE_RATE,
+    log_mel_spectrogram,
+    pad_or_trim,
+)
 from .decoding import DecodingOptions, DecodingResult
 from .decoding import DecodingOptions, DecodingResult
 from .timing import add_word_timestamps
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
+from .utils import (
+    exact_div,
+    format_timestamp,
+    get_writer,
+    make_safe,
+    optional_float,
+    optional_int,
+    str2bool,
+)
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .model import Whisper
     from .model import Whisper
@@ -29,8 +44,8 @@ def transcribe(
     condition_on_previous_text: bool = True,
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
     initial_prompt: Optional[str] = None,
     word_timestamps: bool = False,
     word_timestamps: bool = False,
-    prepend_punctuations: str = "\"\'“¿([{-",
-    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
+    prepend_punctuations: str = "\"'“¿([{-",
+    append_punctuations: str = "\"'.。,,!!??::”)]}、",
     **decode_options,
     **decode_options,
 ):
 ):
     """
     """
@@ -108,12 +123,16 @@ def transcribe(
             decode_options["language"] = "en"
             decode_options["language"] = "en"
         else:
         else:
             if verbose:
             if verbose:
-                print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
+                print(
+                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+                )
             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
             _, probs = model.detect_language(mel_segment)
             _, probs = model.detect_language(mel_segment)
             decode_options["language"] = max(probs, key=probs.get)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
             if verbose is not None:
-                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
+                print(
+                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+                )
 
 
     language: str = decode_options["language"]
     language: str = decode_options["language"]
     task: str = decode_options.get("task", "transcribe")
     task: str = decode_options.get("task", "transcribe")
@@ -123,7 +142,9 @@ def transcribe(
         warnings.warn("Word-level timestamps on translations may not be reliable.")
         warnings.warn("Word-level timestamps on translations may not be reliable.")
 
 
     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
-        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
+        temperatures = (
+            [temperature] if isinstance(temperature, (int, float)) else temperature
+        )
         decode_result = None
         decode_result = None
 
 
         for t in temperatures:
         for t in temperatures:
@@ -140,9 +161,15 @@ def transcribe(
             decode_result = model.decode(segment, options)
             decode_result = model.decode(segment, options)
 
 
             needs_fallback = False
             needs_fallback = False
-            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+            if (
+                compression_ratio_threshold is not None
+                and decode_result.compression_ratio > compression_ratio_threshold
+            ):
                 needs_fallback = True  # too repetitive
                 needs_fallback = True  # too repetitive
-            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+            if (
+                logprob_threshold is not None
+                and decode_result.avg_logprob < logprob_threshold
+            ):
                 needs_fallback = True  # average log probability is too low
                 needs_fallback = True  # average log probability is too low
 
 
             if not needs_fallback:
             if not needs_fallback:
@@ -186,7 +213,9 @@ def transcribe(
 
 
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
     num_frames = mel.shape[-1]
     num_frames = mel.shape[-1]
-    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
+    with tqdm.tqdm(
+        total=num_frames, unit="frames", disable=verbose is not False
+    ) as pbar:
         while seek < num_frames:
         while seek < num_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
             mel_segment = mel[:, seek:]
             mel_segment = mel[:, seek:]
@@ -201,7 +230,10 @@ def transcribe(
             if no_speech_threshold is not None:
             if no_speech_threshold is not None:
                 # no voice activity check
                 # no voice activity check
                 should_skip = result.no_speech_prob > no_speech_threshold
                 should_skip = result.no_speech_prob > no_speech_threshold
-                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                if (
+                    logprob_threshold is not None
+                    and result.avg_logprob > logprob_threshold
+                ):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
                     should_skip = False
 
 
@@ -214,22 +246,35 @@ def transcribe(
             current_tokens = []
             current_tokens = []
 
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
-            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
+                0
+            ].add_(1)
+            if (
+                len(consecutive) > 0
+            ):  # if the output contains two consecutive timestamp tokens
+                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
+                    False,
+                    True,
+                ]:
                     consecutive = consecutive.tolist() + [len(tokens)]
                     consecutive = consecutive.tolist() + [len(tokens)]
 
 
                 last_slice = 0
                 last_slice = 0
                 for current_slice in consecutive:
                 for current_slice in consecutive:
                     sliced_tokens = tokens[last_slice:current_slice]
                     sliced_tokens = tokens[last_slice:current_slice]
-                    start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
-                    end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                    current_segments.append(new_segment(
-                        start=time_offset + start_timestamp_pos * time_precision,
-                        end=time_offset + end_timestamp_pos * time_precision,
-                        tokens=sliced_tokens,
-                        result=result,
-                    ))
+                    start_timestamp_pos = (
+                        sliced_tokens[0].item() - tokenizer.timestamp_begin
+                    )
+                    end_timestamp_pos = (
+                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                    )
+                    current_segments.append(
+                        new_segment(
+                            start=time_offset + start_timestamp_pos * time_precision,
+                            end=time_offset + end_timestamp_pos * time_precision,
+                            tokens=sliced_tokens,
+                            result=result,
+                        )
+                    )
                     current_tokens.append(sliced_tokens.tolist())
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
                     last_slice = current_slice
 
 
@@ -238,23 +283,32 @@ def transcribe(
                     seek += segment_size
                     seek += segment_size
                 else:
                 else:
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
-                    last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                    last_timestamp_pos = (
+                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                    )
                     seek += last_timestamp_pos * input_stride
                     seek += last_timestamp_pos * input_stride
                 all_tokens.extend(tokens[: last_slice + 1].tolist())
                 all_tokens.extend(tokens[: last_slice + 1].tolist())
             else:
             else:
                 duration = segment_duration
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-                if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
+                if (
+                    len(timestamps) > 0
+                    and timestamps[-1].item() != tokenizer.timestamp_begin
+                ):
                     # no consecutive timestamps but it has a timestamp; use the last one.
                     # no consecutive timestamps but it has a timestamp; use the last one.
-                    last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
+                    last_timestamp_pos = (
+                        timestamps[-1].item() - tokenizer.timestamp_begin
+                    )
                     duration = last_timestamp_pos * time_precision
                     duration = last_timestamp_pos * time_precision
 
 
-                current_segments.append(new_segment(
-                    start=time_offset,
-                    end=time_offset + duration,
-                    tokens=tokens,
-                    result=result,
-                ))
+                current_segments.append(
+                    new_segment(
+                        start=time_offset,
+                        end=time_offset + duration,
+                        tokens=tokens,
+                        result=result,
+                    )
+                )
                 current_tokens.append(tokens.tolist())
                 current_tokens.append(tokens.tolist())
                 seek += segment_size
                 seek += segment_size
 
 
@@ -272,9 +326,13 @@ def transcribe(
                     prepend_punctuations=prepend_punctuations,
                     prepend_punctuations=prepend_punctuations,
                     append_punctuations=append_punctuations,
                     append_punctuations=append_punctuations,
                 )
                 )
-                word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
+                word_end_timestamps = [
+                    w["end"] for s in current_segments for w in s["words"]
+                ]
                 if len(consecutive) > 0 and len(word_end_timestamps) > 0:
                 if len(consecutive) > 0 and len(word_end_timestamps) > 0:
-                    seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND)
+                    seek_shift = round(
+                        (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
+                    )
                     if seek_shift > 0:
                     if seek_shift > 0:
                         seek = previous_seek + seek_shift
                         seek = previous_seek + seek_shift
 
 
@@ -293,21 +351,24 @@ def transcribe(
                     current_tokens[i] = []
                     current_tokens[i] = []
 
 
             all_segments.extend(current_segments)
             all_segments.extend(current_segments)
-            all_tokens.extend([token for segment in current_tokens for token in segment])
+            all_tokens.extend(
+                [token for segment in current_tokens for token in segment]
+            )
 
 
             # update progress bar
             # update progress bar
             pbar.update(min(num_frames, seek) - previous_seek)
             pbar.update(min(num_frames, seek) - previous_seek)
 
 
     return dict(
     return dict(
-        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
         segments=all_segments,
         segments=all_segments,
-        language=language
+        language=language,
     )
     )
 
 
 
 
 def cli():
 def cli():
     from . import available_models
     from . import available_models
 
 
+    # fmt: off
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
@@ -339,6 +400,7 @@ def cli():
     parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
     parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
     parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
     parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+    # fmt: on
 
 
     args = parser.parse_args().__dict__
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
     model_name: str = args.pop("model")
@@ -350,7 +412,9 @@ def cli():
 
 
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
         if args["language"] is not None:
         if args["language"] is not None:
-            warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
+            warnings.warn(
+                f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
+            )
         args["language"] = "en"
         args["language"] = "en"
 
 
     temperature = args.pop("temperature")
     temperature = args.pop("temperature")
@@ -363,6 +427,7 @@ def cli():
         torch.set_num_threads(threads)
         torch.set_num_threads(threads)
 
 
     from . import load_model
     from . import load_model
+
     model = load_model(model_name, device=device, download_root=model_dir)
     model = load_model(model_name, device=device, download_root=model_dir)
 
 
     writer = get_writer(output_format, output_dir)
     writer = get_writer(output_format, output_dir)
@@ -371,5 +436,5 @@ def cli():
         writer(result, audio_path)
         writer(result, audio_path)
 
 
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     cli()
     cli()

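Note the new `# fmt: off` / `# fmt: on` pair around the argparse block: black skips everything between these markers, so each `add_argument` call keeps its single-line help string instead of being exploded across the 88-column limit. A minimal demonstration of the directive:

    import argparse

    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("--model", default="small", help="name of the model to use")
    # fmt: on
    print(parser.parse_args(["--model", "tiny"]).model)  # tiny
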
+ 41 - 24
whisper/triton_ops.py

@@ -1,8 +1,7 @@
-import math
+from functools import lru_cache
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch
-from functools import lru_cache
 
 
 try:
 try:
     import triton
     import triton
@@ -12,7 +11,9 @@ except ImportError:
 
 
 
 
 @triton.jit
 @triton.jit
-def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
+def dtw_kernel(
+    cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
+):
     offsets = tl.arange(0, BLOCK_SIZE)
     offsets = tl.arange(0, BLOCK_SIZE)
     mask = offsets < M
     mask = offsets < M
 
 
@@ -42,37 +43,53 @@ def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_
 @lru_cache(maxsize=None)
 @lru_cache(maxsize=None)
 def median_kernel(filter_width: int):
 def median_kernel(filter_width: int):
     @triton.jit
     @triton.jit
-    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == filter_width
+    def kernel(
+        y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
+    ):  # x.shape[-1] == filter_width
         row_idx = tl.program_id(0)
         row_idx = tl.program_id(0)
         offsets = tl.arange(0, BLOCK_SIZE)
         offsets = tl.arange(0, BLOCK_SIZE)
         mask = offsets < y_stride
         mask = offsets < y_stride
 
 
-        x_ptr = x + row_idx * x_stride
+        x_ptr = x + row_idx * x_stride  # noqa: F841
         y_ptr = y + row_idx * y_stride
         y_ptr = y + row_idx * y_stride
 
 
-        LOAD_ALL_ROWS_HERE
+        LOAD_ALL_ROWS_HERE  # noqa: F821
 
 
-        BUBBLESORT_HERE
+        BUBBLESORT_HERE  # noqa: F821
 
 
-        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)
+        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821
 
 
     kernel = triton.JITFunction(kernel.fn)
     kernel = triton.JITFunction(kernel.fn)
-    kernel.src = kernel.src.replace("    LOAD_ALL_ROWS_HERE", "\n".join([
-        f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
-        for i in range(filter_width)
-    ]))
-    kernel.src = kernel.src.replace("    BUBBLESORT_HERE", "\n\n".join([
-        "\n\n".join([
-            "\n".join([
-                f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
-                f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
-                f"    row{j} = smaller",
-                f"    row{j + 1} = larger",
-            ])
-            for j in range(filter_width - i - 1)
-        ])
-        for i in range(filter_width // 2 + 1)
-    ]))
+    kernel.src = kernel.src.replace(
+        "    LOAD_ALL_ROWS_HERE",
+        "\n".join(
+            [
+                f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
+                for i in range(filter_width)
+            ]
+        ),
+    )
+    kernel.src = kernel.src.replace(
+        "    BUBBLESORT_HERE",
+        "\n\n".join(
+            [
+                "\n\n".join(
+                    [
+                        "\n".join(
+                            [
+                                f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
+                                f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
+                                f"    row{j} = smaller",
+                                f"    row{j + 1} = larger",
+                            ]
+                        )
+                        for j in range(filter_width - i - 1)
+                    ]
+                )
+                for i in range(filter_width // 2 + 1)
+            ]
+        ),
+    )
     kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
     kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
 
 
     return kernel
     return kernel

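The `median_kernel` restructuring is purely cosmetic: the kernel source is still generated by substituting the `LOAD_ALL_ROWS_HERE` and `BUBBLESORT_HERE` placeholders with unrolled lines. The expansion is plain string work and can be checked without Triton installed:

    filter_width = 3
    loads = "\n".join(
        f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
        for i in range(filter_width)
    )
    src = "    LOAD_ALL_ROWS_HERE"
    print(src.replace("    LOAD_ALL_ROWS_HERE", loads))
    # prints one tl.load line per row of the filter window
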
+ 24 - 12
whisper/utils.py

@@ -7,11 +7,14 @@ from typing import Callable, TextIO
 system_encoding = sys.getdefaultencoding()
 system_encoding = sys.getdefaultencoding()
 
 
 if system_encoding != "utf-8":
 if system_encoding != "utf-8":
+
     def make_safe(string):
     def make_safe(string):
         # replaces any character not representable using the system default encoding with an '?',
         # replaces any character not representable using the system default encoding with a '?',
         # replaces any character not representable using the system default encoding with a '?',
         # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
         return string.encode(system_encoding, errors="replace").decode(system_encoding)
         return string.encode(system_encoding, errors="replace").decode(system_encoding)
+
 else:
 else:
+
     def make_safe(string):
     def make_safe(string):
         # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
         # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
         return string
         return string
@@ -43,7 +46,9 @@ def compression_ratio(text) -> float:
     return len(text_bytes) / len(zlib.compress(text_bytes))
     return len(text_bytes) / len(zlib.compress(text_bytes))
 
 
 
 
-def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
+def format_timestamp(
+    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+):
     assert seconds >= 0, "non-negative timestamp expected"
     assert seconds >= 0, "non-negative timestamp expected"
     milliseconds = round(seconds * 1000.0)
     milliseconds = round(seconds * 1000.0)
 
 
@@ -57,7 +62,9 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
     milliseconds -= seconds * 1_000
     milliseconds -= seconds * 1_000
 
 
     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
-    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    return (
+        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    )
 
 
 
 
 class ResultWriter:
 class ResultWriter:
@@ -68,7 +75,9 @@ class ResultWriter:
 
 
     def __call__(self, result: dict, audio_path: str):
     def __call__(self, result: dict, audio_path: str):
         audio_basename = os.path.basename(audio_path)
         audio_basename = os.path.basename(audio_path)
-        output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
+        output_path = os.path.join(
+            self.output_dir, audio_basename + "." + self.extension
+        )
 
 
         with open(output_path, "w", encoding="utf-8") as f:
         with open(output_path, "w", encoding="utf-8") as f:
             self.write_result(result, file=f)
             self.write_result(result, file=f)
@@ -82,7 +91,7 @@ class WriteTXT(ResultWriter):
 
 
     def write_result(self, result: dict, file: TextIO):
     def write_result(self, result: dict, file: TextIO):
         for segment in result["segments"]:
         for segment in result["segments"]:
-            print(segment['text'].strip(), file=file, flush=True)
+            print(segment["text"].strip(), file=file, flush=True)
 
 
 
 
 class SubtitlesWriter(ResultWriter):
 class SubtitlesWriter(ResultWriter):
@@ -93,7 +102,7 @@ class SubtitlesWriter(ResultWriter):
         for segment in result["segments"]:
         for segment in result["segments"]:
             segment_start = self.format_timestamp(segment["start"])
             segment_start = self.format_timestamp(segment["start"])
             segment_end = self.format_timestamp(segment["end"])
             segment_end = self.format_timestamp(segment["end"])
-            segment_text = segment['text'].strip().replace('-->', '->')
+            segment_text = segment["text"].strip().replace("-->", "->")
 
 
             if word_timings := segment.get("words", None):
             if word_timings := segment.get("words", None):
                 all_words = [timing["word"] for timing in word_timings]
                 all_words = [timing["word"] for timing in word_timings]
@@ -106,7 +115,10 @@ class SubtitlesWriter(ResultWriter):
                         yield last, start, segment_text
                         yield last, start, segment_text
 
 
                     yield start, end, "".join(
                     yield start, end, "".join(
-                        [f"<u>{word}</u>" if j == i else word for j, word in enumerate(all_words)]
+                        [
+                            f"<u>{word}</u>" if j == i else word
+                            for j, word in enumerate(all_words)
+                        ]
                     )
                     )
                     last = end
                     last = end
 
 
@@ -126,7 +138,7 @@ class SubtitlesWriter(ResultWriter):
 class WriteVTT(SubtitlesWriter):
 class WriteVTT(SubtitlesWriter):
     extension: str = "vtt"
     extension: str = "vtt"
     always_include_hours: bool = False
     always_include_hours: bool = False
-    decimal_marker: str = '.'
+    decimal_marker: str = "."
 
 
     def write_result(self, result: dict, file: TextIO):
     def write_result(self, result: dict, file: TextIO):
         print("WEBVTT\n", file=file)
         print("WEBVTT\n", file=file)
@@ -137,7 +149,7 @@ class WriteVTT(SubtitlesWriter):
 class WriteSRT(SubtitlesWriter):
 class WriteSRT(SubtitlesWriter):
     extension: str = "srt"
     extension: str = "srt"
     always_include_hours: bool = True
     always_include_hours: bool = True
-    decimal_marker: str = ','
+    decimal_marker: str = ","
 
 
     def write_result(self, result: dict, file: TextIO):
     def write_result(self, result: dict, file: TextIO):
         for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
         for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
@@ -153,14 +165,15 @@ class WriteTSV(ResultWriter):
     an environment setting a language encoding that causes the decimal in a floating point number
     an environment setting a language encoding that causes the decimal in a floating point number
     to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
     to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
     """
     """
+
     extension: str = "tsv"
     extension: str = "tsv"
 
 
     def write_result(self, result: dict, file: TextIO):
     def write_result(self, result: dict, file: TextIO):
         print("start", "end", "text", sep="\t", file=file)
         print("start", "end", "text", sep="\t", file=file)
         for segment in result["segments"]:
         for segment in result["segments"]:
-            print(round(1000 * segment['start']), file=file, end="\t")
-            print(round(1000 * segment['end']), file=file, end="\t")
-            print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
+            print(round(1000 * segment["start"]), file=file, end="\t")
+            print(round(1000 * segment["end"]), file=file, end="\t")
+            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
 
 
 
 
 class WriteJSON(ResultWriter):
 class WriteJSON(ResultWriter):
@@ -189,4 +202,3 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
         return write_all
         return write_all
 
 
     return writers[output_format](output_dir)
     return writers[output_format](output_dir)
-
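
As a final sanity check, `format_timestamp` is unchanged by the reflow; the sketch below re-creates it, assuming the elided middle of the function is the standard divmod cascade from the repository:

    def format_timestamp(
        seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
    ):
        assert seconds >= 0, "non-negative timestamp expected"
        milliseconds = round(seconds * 1000.0)
        hours = milliseconds // 3_600_000
        milliseconds -= hours * 3_600_000
        minutes = milliseconds // 60_000
        milliseconds -= minutes * 60_000
        seconds = milliseconds // 1_000
        milliseconds -= seconds * 1_000
        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
        return (
            f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
        )

    print(format_timestamp(3725.5))                    # 01:02:05.500
    print(format_timestamp(7.25, decimal_marker=","))  # 00:07,250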