فهرست منبع

drop python 3.7 support (#889)

Jong Wook Kim 2 سال پیش
والد
کامیت
a6b36ede1f
3فایلهای تغییر یافته به همراه33 افزوده شده و 49 حذف شده
  1. 6 13
      whisper/decoding.py
  2. 13 25
      whisper/tokenizer.py
  3. 14 11
      whisper/transcribe.py

+ 6 - 13
whisper/decoding.py

@@ -252,11 +252,10 @@ class GreedyDecoder(TokenDecoder):
         self.eot = eot
 
     def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
-        temperature = self.temperature
-        if temperature == 0:
+        if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
-            next_tokens = Categorical(logits=logits / temperature).sample()
+            next_tokens = Categorical(logits=logits / self.temperature).sample()
 
         logprobs = F.log_softmax(logits.float(), dim=-1)
         current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
@@ -511,10 +510,8 @@ class DecodingTask:
 
     def _get_initial_tokens(self) -> Tuple[int]:
         tokens = list(self.sot_sequence)
-        prefix = self.options.prefix
-        prompt = self.options.prompt
 
-        if prefix:
+        if prefix := self.options.prefix:
             prefix_tokens = (
                 self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
             )
@@ -523,7 +520,7 @@ class DecodingTask:
                 prefix_tokens = prefix_tokens[-max_prefix_len:]
             tokens = tokens + prefix_tokens
 
-        if prompt:
+        if prompt := self.options.prompt:
             prompt_tokens = (
                 self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
             )
@@ -698,13 +695,9 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOpt
     result: Union[DecodingResult, List[DecodingResult]]
         The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
     """
-    single = mel.ndim == 2
-    if single:
+    if single := mel.ndim == 2:
         mel = mel.unsqueeze(0)
 
     result = DecodingTask(model, options).run(mel)
-    
-    if single:
-        result = result[0]
 
-    return result
+    return result[0] if single else result

+ 13 - 25
whisper/tokenizer.py

@@ -1,6 +1,6 @@
 import os
 from dataclasses import dataclass
-from functools import lru_cache
+from functools import lru_cache, cached_property
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -156,43 +156,35 @@ class Tokenizer:
         outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
         return "".join(outputs)
 
-    @property
-    @lru_cache()
+    @cached_property
     def eot(self) -> int:
         return self.tokenizer.eos_token_id
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot(self) -> int:
         return self._get_single_token_id("<|startoftranscript|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_lm(self) -> int:
         return self._get_single_token_id("<|startoflm|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_prev(self) -> int:
         return self._get_single_token_id("<|startofprev|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_speech(self) -> int:
         return self._get_single_token_id("<|nospeech|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_timestamps(self) -> int:
         return self._get_single_token_id("<|notimestamps|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def timestamp_begin(self) -> int:
         return self.tokenizer.all_special_ids[-1] + 1
 
-    @property
-    @lru_cache()
+    @cached_property
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
@@ -210,8 +202,7 @@ class Tokenizer:
 
         raise KeyError(f"Language {self.language} not found in tokenizer.")
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_tokens(self) -> Tuple[int]:
         result = []
         for token, token_id in zip(
@@ -222,18 +213,15 @@ class Tokenizer:
                 result.append(token_id)
         return tuple(result)
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_codes(self) -> Tuple[str]:
         return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
         return tuple(list(self.sot_sequence) + [self.no_timestamps])
 
-    @property
-    @lru_cache()
+    @cached_property
     def non_speech_tokens(self) -> Tuple[int]:
         """
         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech

+ 14 - 11
whisper/transcribe.py

@@ -26,6 +26,7 @@ def transcribe(
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
+    initial_prompt: Optional[str] = None,
     **decode_options,
 ):
     """
@@ -138,10 +139,11 @@ def transcribe(
     all_segments = []
     prompt_reset_since = 0
 
-    initial_prompt = decode_options.pop("initial_prompt", None) or []
-    if initial_prompt:
-        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
-        all_tokens.extend(initial_prompt)
+    if initial_prompt is not None:
+        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+        all_tokens.extend(initial_prompt_tokens)
+    else:
+        initial_prompt_tokens = []
 
     def add_segment(
         *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
@@ -243,7 +245,11 @@ def transcribe(
             pbar.update(min(num_frames, seek) - previous_seek_value)
             previous_seek_value = seek
 
-    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
+    return dict(
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
+        segments=all_segments,
+        language=language
+    )
 
 
 def cli():
@@ -292,21 +298,18 @@ def cli():
         args["language"] = "en"
 
     temperature = args.pop("temperature")
-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
-    if temperature_increment_on_fallback is not None:
-        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
+    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
+        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
     else:
         temperature = [temperature]
 
-    threads = args.pop("threads")
-    if threads > 0:
+    if (threads := args.pop("threads")) > 0:
         torch.set_num_threads(threads)
 
     from . import load_model
     model = load_model(model_name, device=device, download_root=model_dir)
 
     writer = get_writer(output_format, output_dir)
-
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
         writer(result, audio_path)