Просмотр исходного кода

large-v3 (#1761)

* mel_filters() loads 128 mel bins

* can load 100-language models

* large-v3 checkpoint and evals

* add mandarin alias

* remove unused path

* flake8 fix

* formatting fix
Jong Wook Kim 1 год назад
Родитель
Сommit
c5d4256076

+ 2 - 2
README.md

@@ -69,9 +69,9 @@ There are five model sizes, four with English-only versions, offering speed and
 
 The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
 
-Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by languages of the Fleurs dataset using the `large-v2` model (The smaller the numbers, the better the performance). Additional WER scores corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4. Meanwhile, more BLEU (Bilingual Evaluation Understudy) scores can be found in Appendix D.3. Both are found in [the paper](https://arxiv.org/abs/2212.04356). 
+Whisper's performance varies widely depending on the language. The figure below shows a performance breakdown of `large-v3` and `large-v2` models by language, using WERs (word error rates) or CER (character error rates, shown in *Italic*) evaluated on the Common Voice 15 and Fleurs datasets. Additional WER/CER metrics corresponding to the other models and datasets can be found in Appendix D.1, D.2, and D.4 of [the paper](https://arxiv.org/abs/2212.04356), as well as the BLEU (Bilingual Evaluation Understudy) scores for translation in Appendix D.3.
 
-![WER breakdown by language](https://raw.githubusercontent.com/openai/whisper/main/language-breakdown.svg)
+![WER breakdown by language](https://github.com/openai/whisper/assets/266841/f4619d66-1058-4005-8f67-a9d811b77c62)
 
 
 

Разница между файлами не показана из-за своего большого размера
+ 278 - 658
language-breakdown.svg


+ 2 - 2
model-card.md

@@ -17,12 +17,12 @@ The Whisper models are trained for speech recognition and translation tasks, cap
 | medium |   769 M    |         ✓          |         ✓          |
 | large  |   1550 M   |                    |         ✓          |
 
-In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661).
+In December 2022, we [released an improved large model named `large-v2`](https://github.com/openai/whisper/discussions/661), and `large-v3` in November 2023.
 
 
 ### Release date
 
-September 2022 (original series) and December 2022 (`large-v2`)
+September 2022 (original series), December 2022 (`large-v2`), and November 2023 (`large-v3`)
 
 ### Model type
 

+ 1 - 1
tests/test_transcribe.py

@@ -25,7 +25,7 @@ def test_transcribe(model_name: str):
     assert "your country" in transcription
     assert "do for you" in transcription
 
-    tokenizer = get_tokenizer(model.is_multilingual)
+    tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
     all_tokens = [t for s in result["segments"] for t in s["tokens"]]
     assert tokenizer.decode(all_tokens) == result["text"]
     assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")

+ 4 - 2
whisper/__init__.py

@@ -25,7 +25,8 @@ _MODELS = {
     "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
     "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
     "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
-    "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
+    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
+    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
 }
 
 # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
@@ -41,7 +42,8 @@ _ALIGNMENT_HEADS = {
     "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
     "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
-    "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
+    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
 }
 
 

BIN
whisper/assets/mel_filters.npz


+ 4 - 4
whisper/audio.py

@@ -12,7 +12,6 @@ from .utils import exact_div
 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000
 N_FFT = 400
-N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
@@ -90,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
 
 
 @lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
+def mel_filters(device, n_mels: int) -> torch.Tensor:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -98,9 +97,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
         np.savez_compressed(
             "mel_filters.npz",
             mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
         )
     """
-    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
 
     filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
     with np.load(filters_path, allow_pickle=False) as f:
@@ -109,7 +109,7 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
 
 def log_mel_spectrogram(
     audio: Union[str, np.ndarray, torch.Tensor],
-    n_mels: int = N_MELS,
+    n_mels: int = 80,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
 ):

+ 7 - 2
whisper/decoding.py

@@ -32,7 +32,9 @@ def detect_language(
         list of dictionaries containing the probability distribution over all languages.
     """
     if tokenizer is None:
-        tokenizer = get_tokenizer(model.is_multilingual)
+        tokenizer = get_tokenizer(
+            model.is_multilingual, num_languages=model.num_languages
+        )
     if (
         tokenizer.language is None
         or tokenizer.language_token not in tokenizer.sot_sequence
@@ -514,7 +516,10 @@ class DecodingTask:
 
         language = options.language or "en"
         tokenizer = get_tokenizer(
-            model.is_multilingual, language=language, task=options.task
+            model.is_multilingual,
+            num_languages=model.num_languages,
+            language=language,
+            task=options.task,
         )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)

+ 7 - 2
whisper/model.py

@@ -236,7 +236,8 @@ class Whisper(nn.Module):
             self.dims.n_text_head,
             self.dims.n_text_layer,
         )
-        # use the last half layers for alignment by default; see `set_alignment_heads()` below
+        # use the last half among the decoder layers for time alignment by default;
+        # to use a specific set of heads, see `set_alignment_heads()` below.
         all_heads = torch.zeros(
             self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
         )
@@ -269,7 +270,11 @@ class Whisper(nn.Module):
 
     @property
     def is_multilingual(self):
-        return self.dims.n_vocab == 51865
+        return self.dims.n_vocab >= 51865
+
+    @property
+    def num_languages(self):
+        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
 
     def install_kv_cache_hooks(self, cache: Optional[dict] = None):
         """

+ 18 - 9
whisper/tokenizer.py

@@ -107,6 +107,7 @@ LANGUAGES = {
     "ba": "bashkir",
     "jw": "javanese",
     "su": "sundanese",
+    "yue": "cantonese",
 }
 
 # language code lookup by name, with a few language aliases
@@ -123,6 +124,7 @@ TO_LANGUAGE_CODE = {
     "moldovan": "ro",
     "sinhalese": "si",
     "castilian": "es",
+    "mandarin": "zh",
 }
 
 
@@ -131,6 +133,7 @@ class Tokenizer:
     """A thin wrapper around `tiktoken` providing quick access to special tokens"""
 
     encoding: tiktoken.Encoding
+    num_languages: int
     language: Optional[str] = None
     task: Optional[str] = None
     sot_sequence: Tuple[int] = ()
@@ -145,7 +148,7 @@ class Tokenizer:
         translate: int = self.special_tokens["<|translate|>"]
         transcribe: int = self.special_tokens["<|transcribe|>"]
 
-        langs = tuple(LANGUAGES.keys())
+        langs = tuple(LANGUAGES.keys())[: self.num_languages]
         sot_sequence = [sot]
         if self.language is not None:
             sot_sequence.append(sot + 1 + langs.index(self.language))
@@ -211,10 +214,13 @@ class Tokenizer:
         if self.language is None:
             raise ValueError("This tokenizer does not have language token configured")
 
-        if token := self.special_tokens.get(f"<|{self.language}|>", None):
+        return self.to_language_token(self.language)
+
+    def to_language_token(self, language):
+        if token := self.special_tokens.get(f"<|{language}|>", None):
             return token
 
-        raise KeyError(f"Language {self.language} not found in tokenizer.")
+        raise KeyError(f"Language {language} not found in tokenizer.")
 
     @cached_property
     def all_language_tokens(self) -> Tuple[int]:
@@ -222,7 +228,7 @@ class Tokenizer:
         for token, token_id in self.special_tokens.items():
             if token.strip("<|>") in LANGUAGES:
                 result.append(token_id)
-        return tuple(result)
+        return tuple(result)[: self.num_languages]
 
     @cached_property
     def all_language_codes(self) -> Tuple[str]:
@@ -269,7 +275,7 @@ class Tokenizer:
         return tuple(sorted(result))
 
     def split_to_word_tokens(self, tokens: List[int]):
-        if self.language in {"zh", "ja", "th", "lo", "my"}:
+        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
             # These languages don't typically use spaces, so it is difficult to split words
             # without morpheme analysis. Here, we instead split words at any
             # position where the tokens are decoded as valid unicode points
@@ -322,7 +328,7 @@ class Tokenizer:
 
 
 @lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2"):
+def get_encoding(name: str = "gpt2", num_languages: int = 99):
     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
     ranks = {
         base64.b64decode(token): int(rank)
@@ -334,7 +340,7 @@ def get_encoding(name: str = "gpt2"):
     specials = [
         "<|endoftext|>",
         "<|startoftranscript|>",
-        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
         "<|translate|>",
         "<|transcribe|>",
         "<|startoflm|>",
@@ -361,6 +367,7 @@ def get_encoding(name: str = "gpt2"):
 def get_tokenizer(
     multilingual: bool,
     *,
+    num_languages: int = 99,
     language: Optional[str] = None,
     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
 ) -> Tokenizer:
@@ -381,6 +388,8 @@ def get_tokenizer(
         language = None
         task = None
 
-    encoding = get_encoding(name=encoding_name)
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
 
-    return Tokenizer(encoding=encoding, language=language, task=task)
+    return Tokenizer(
+        encoding=encoding, num_languages=num_languages, language=language, task=task
+    )

+ 7 - 2
whisper/transcribe.py

@@ -119,7 +119,7 @@ def transcribe(
         decode_options["fp16"] = False
 
     # Pad 30-seconds of silence to the input audio, for slicing
-    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
     content_frames = mel.shape[-1] - N_FRAMES
 
     if decode_options.get("language", None) is None:
@@ -140,7 +140,12 @@ def transcribe(
 
     language: str = decode_options["language"]
     task: str = decode_options.get("task", "transcribe")
-    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
+    tokenizer = get_tokenizer(
+        model.is_multilingual,
+        num_languages=model.num_languages,
+        language=language,
+        task=task,
+    )
 
     if word_timestamps and task == "translate":
         warnings.warn("Word-level timestamps on translations may not be reliable.")

Некоторые файлы не были показаны из-за большого количества измененных файлов