Browse Source

nocaptions -> nospeech to match the paper figure

Jong Wook Kim 2 years ago
parent
commit
15ab548263
3 changed files with 27 additions and 39 deletions
  1. 14 14
      whisper/decoding.py
  2. 3 3
      whisper/tokenizer.py
  3. 10 22
      whisper/transcribe.py

+ 14 - 14
whisper/decoding.py

@@ -108,7 +108,7 @@ class DecodingResult:
     tokens: List[int] = field(default_factory=list)
     text: str = ""
     avg_logprob: float = np.nan
-    no_caption_prob: float = np.nan
+    no_speech_prob: float = np.nan
     temperature: float = np.nan
     compression_ratio: float = np.nan
 
@@ -543,9 +543,9 @@ class DecodingTask:
         suppress_tokens.extend(
             [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
         )
-        if self.tokenizer.no_captions is not None:
-            # no-captions probability is collected separately
-            suppress_tokens.append(self.tokenizer.no_captions)
+        if self.tokenizer.no_speech is not None:
+            # no-speech probability is collected separately
+            suppress_tokens.append(self.tokenizer.no_speech)
 
         return tuple(sorted(set(suppress_tokens)))
 
@@ -580,15 +580,15 @@ class DecodingTask:
         assert audio_features.shape[0] == tokens.shape[0]
         n_batch = tokens.shape[0]
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
-        no_caption_probs = [np.nan] * n_batch
+        no_speech_probs = [np.nan] * n_batch
 
         try:
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)
 
-                if i == 0 and self.tokenizer.no_captions is not None:  # save no_caption_probs
+                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
-                    no_caption_probs = probs_at_sot[:, self.tokenizer.no_captions].tolist()
+                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
 
                 # now we need to consider the logits at the last token only
                 logits = logits[:, -1]
@@ -605,7 +605,7 @@ class DecodingTask:
         finally:
             self.inference.cleanup_caching()
 
-        return tokens, sum_logprobs, no_caption_probs
+        return tokens, sum_logprobs, no_speech_probs
 
     @torch.no_grad()
     def run(self, mel: Tensor) -> List[DecodingResult]:
@@ -629,12 +629,12 @@ class DecodingTask:
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
 
         # call the main sampling loop
-        tokens, sum_logprobs, no_caption_probs = self._main_loop(audio_features, tokens)
+        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
 
         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
-        no_caption_probs = no_caption_probs[:: self.n_group]
-        assert audio_features.shape[0] == len(no_caption_probs) == n_audio
+        no_speech_probs = no_speech_probs[:: self.n_group]
+        assert audio_features.shape[0] == len(no_speech_probs) == n_audio
 
         tokens = tokens.reshape(n_audio, self.n_group, -1)
         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
@@ -653,7 +653,7 @@ class DecodingTask:
         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
         avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
 
-        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_caption_probs)
+        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
 
@@ -664,11 +664,11 @@ class DecodingTask:
                 tokens=tokens,
                 text=text,
                 avg_logprob=avg_logprob,
-                no_caption_prob=no_caption_prob,
+                no_speech_prob=no_speech_prob,
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
             )
-            for text, language, tokens, features, avg_logprob, no_caption_prob in zip(*fields)
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
         ]
 
 

+ 3 - 3
whisper/tokenizer.py

@@ -178,8 +178,8 @@ class Tokenizer:
 
     @property
     @lru_cache()
-    def no_captions(self) -> int:
-        return self._get_single_token_id("<|nocaptions|>")
+    def no_speech(self) -> int:
+        return self._get_single_token_id("<|nospeech|>")
 
     @property
     @lru_cache()
@@ -283,7 +283,7 @@ def build_tokenizer(name: str = "gpt2"):
         "<|transcribe|>",
         "<|startoflm|>",
         "<|startofprev|>",
-        "<|nocaptions|>",
+        "<|nospeech|>",
         "<|notimestamps|>",
     ]
 

+ 10 - 22
whisper/transcribe.py

@@ -23,7 +23,7 @@ def transcribe(
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
     logprob_threshold: Optional[float] = -1.0,
-    no_captions_threshold: Optional[float] = 0.6,
+    no_speech_threshold: Optional[float] = 0.6,
     **decode_options,
 ):
     """
@@ -50,8 +50,8 @@ def transcribe(
     logprob_threshold: float
         If the average log probability over sampled tokens is below this value, treat as failed
 
-    no_captions_threshold: float
-        If the no_captions probability is higher than this value AND the average log probability
+    no_speech_threshold: float
+        If the no_speech probability is higher than this value AND the average log probability
         over sampled tokens is below `logprob_threshold`, consider the segment as silent
 
     decode_options: dict
@@ -148,7 +148,7 @@ def transcribe(
                 "temperature": result.temperature,
                 "avg_logprob": result.avg_logprob,
                 "compression_ratio": result.compression_ratio,
-                "no_caption_prob": result.no_caption_prob,
+                "no_speech_prob": result.no_speech_prob,
             }
         )
         if verbose:
@@ -163,11 +163,11 @@ def transcribe(
         result = decode_with_fallback(segment)[0]
         tokens = torch.tensor(result.tokens)
 
-        if no_captions_threshold is not None:
+        if no_speech_threshold is not None:
             # no voice activity check
-            should_skip = result.no_caption_prob > no_captions_threshold
+            should_skip = result.no_speech_prob > no_speech_threshold
             if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
-                # don't skip if the logprob is high enough, despite the no_captions_prob
+                # don't skip if the logprob is high enough, despite the no_speech_prob
                 should_skip = False
 
             if should_skip:
@@ -249,7 +249,7 @@ def cli():
     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
-    parser.add_argument("--no_caption_threshold", type=optional_float, default=0.6, help="if the probability of the <|nocaptions|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
 
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
@@ -261,12 +261,8 @@ def cli():
         warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
         args["language"] = "en"
 
-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
-    compression_ratio_threshold = args.pop("compression_ratio_threshold")
-    logprob_threshold = args.pop("logprob_threshold")
-    no_caption_threshold = args.pop("no_caption_threshold")
-
     temperature = args.pop("temperature")
+    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
     if temperature_increment_on_fallback is not None:
         temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
     else:
@@ -276,15 +272,7 @@ def cli():
     model = load_model(model_name, device=device)
 
     for audio_path in args.pop("audio"):
-        result = transcribe(
-            model,
-            audio_path,
-            temperature=temperature,
-            compression_ratio_threshold=compression_ratio_threshold,
-            logprob_threshold=logprob_threshold,
-            no_captions_threshold=no_caption_threshold,
-            **args,
-        )
+        result = transcribe(model, audio_path, temperature=temperature, **args)
 
         audio_basename = os.path.basename(audio_path)