Browse Source

Rename nocaptions -> nospeech to match the paper figure

Jong Wook Kim 2 years ago
parent
commit
15ab548263
3 changed files with 27 additions and 39 deletions
  1. whisper/decoding.py (+14 -14)
  2. whisper/tokenizer.py (+3 -3)
  3. whisper/transcribe.py (+10 -22)

+ 14 - 14
whisper/decoding.py

@@ -108,7 +108,7 @@ class DecodingResult:
     tokens: List[int] = field(default_factory=list)
     tokens: List[int] = field(default_factory=list)
     text: str = ""
     text: str = ""
     avg_logprob: float = np.nan
     avg_logprob: float = np.nan
-    no_caption_prob: float = np.nan
+    no_speech_prob: float = np.nan
     temperature: float = np.nan
     temperature: float = np.nan
     compression_ratio: float = np.nan
     compression_ratio: float = np.nan
 
 
@@ -543,9 +543,9 @@ class DecodingTask:
         suppress_tokens.extend(
         suppress_tokens.extend(
             [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
             [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
         )
         )
-        if self.tokenizer.no_captions is not None:
-            # no-captions probability is collected separately
-            suppress_tokens.append(self.tokenizer.no_captions)
+        if self.tokenizer.no_speech is not None:
+            # no-speech probability is collected separately
+            suppress_tokens.append(self.tokenizer.no_speech)
 
 
         return tuple(sorted(set(suppress_tokens)))
         return tuple(sorted(set(suppress_tokens)))
 
 
@@ -580,15 +580,15 @@ class DecodingTask:
         assert audio_features.shape[0] == tokens.shape[0]
         assert audio_features.shape[0] == tokens.shape[0]
         n_batch = tokens.shape[0]
         n_batch = tokens.shape[0]
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
-        no_caption_probs = [np.nan] * n_batch
+        no_speech_probs = [np.nan] * n_batch
 
 
         try:
         try:
             for i in range(self.sample_len):
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)
                 logits = self.inference.logits(tokens, audio_features)
 
 
-                if i == 0 and self.tokenizer.no_captions is not None:  # save no_caption_probs
+                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
-                    no_caption_probs = probs_at_sot[:, self.tokenizer.no_captions].tolist()
+                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
 
 
                 # now we need to consider the logits at the last token only
                 # now we need to consider the logits at the last token only
                 logits = logits[:, -1]
                 logits = logits[:, -1]
@@ -605,7 +605,7 @@ class DecodingTask:
         finally:
         finally:
             self.inference.cleanup_caching()
             self.inference.cleanup_caching()
 
 
-        return tokens, sum_logprobs, no_caption_probs
+        return tokens, sum_logprobs, no_speech_probs
 
 
     @torch.no_grad()
     @torch.no_grad()
     def run(self, mel: Tensor) -> List[DecodingResult]:
     def run(self, mel: Tensor) -> List[DecodingResult]:
@@ -629,12 +629,12 @@ class DecodingTask:
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
 
 
         # call the main sampling loop
         # call the main sampling loop
-        tokens, sum_logprobs, no_caption_probs = self._main_loop(audio_features, tokens)
+        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
 
 
         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
         audio_features = audio_features[:: self.n_group]
-        no_caption_probs = no_caption_probs[:: self.n_group]
-        assert audio_features.shape[0] == len(no_caption_probs) == n_audio
+        no_speech_probs = no_speech_probs[:: self.n_group]
+        assert audio_features.shape[0] == len(no_speech_probs) == n_audio
 
 
         tokens = tokens.reshape(n_audio, self.n_group, -1)
         tokens = tokens.reshape(n_audio, self.n_group, -1)
         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
@@ -653,7 +653,7 @@ class DecodingTask:
         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
         avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
         avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
 
 
-        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_caption_probs)
+        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
         if len(set(map(len, fields))) != 1:
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
 
 
@@ -664,11 +664,11 @@ class DecodingTask:
                 tokens=tokens,
                 tokens=tokens,
                 text=text,
                 text=text,
                 avg_logprob=avg_logprob,
                 avg_logprob=avg_logprob,
-                no_caption_prob=no_caption_prob,
+                no_speech_prob=no_speech_prob,
                 temperature=self.options.temperature,
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
                 compression_ratio=compression_ratio(text),
             )
             )
-            for text, language, tokens, features, avg_logprob, no_caption_prob in zip(*fields)
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
         ]
         ]
 
 
 
 

+ 3 - 3
whisper/tokenizer.py

@@ -178,8 +178,8 @@ class Tokenizer:
 
 
     @property
     @property
     @lru_cache()
     @lru_cache()
-    def no_captions(self) -> int:
-        return self._get_single_token_id("<|nocaptions|>")
+    def no_speech(self) -> int:
+        return self._get_single_token_id("<|nospeech|>")
 
 
     @property
     @property
     @lru_cache()
     @lru_cache()
@@ -283,7 +283,7 @@ def build_tokenizer(name: str = "gpt2"):
         "<|transcribe|>",
         "<|transcribe|>",
         "<|startoflm|>",
         "<|startoflm|>",
         "<|startofprev|>",
         "<|startofprev|>",
-        "<|nocaptions|>",
+        "<|nospeech|>",
         "<|notimestamps|>",
         "<|notimestamps|>",
     ]
     ]
 
 

+ 10 - 22
whisper/transcribe.py

@@ -23,7 +23,7 @@ def transcribe(
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
     compression_ratio_threshold: Optional[float] = 2.4,
     logprob_threshold: Optional[float] = -1.0,
     logprob_threshold: Optional[float] = -1.0,
-    no_captions_threshold: Optional[float] = 0.6,
+    no_speech_threshold: Optional[float] = 0.6,
     **decode_options,
     **decode_options,
 ):
 ):
     """
     """
@@ -50,8 +50,8 @@ def transcribe(
     logprob_threshold: float
     logprob_threshold: float
         If the average log probability over sampled tokens is below this value, treat as failed
         If the average log probability over sampled tokens is below this value, treat as failed
 
 
-    no_captions_threshold: float
-        If the no_captions probability is higher than this value AND the average log probability
+    no_speech_threshold: float
+        If the no_speech probability is higher than this value AND the average log probability
         over sampled tokens is below `logprob_threshold`, consider the segment as silent
         over sampled tokens is below `logprob_threshold`, consider the segment as silent
 
 
     decode_options: dict
     decode_options: dict
@@ -148,7 +148,7 @@ def transcribe(
                 "temperature": result.temperature,
                 "temperature": result.temperature,
                 "avg_logprob": result.avg_logprob,
                 "avg_logprob": result.avg_logprob,
                 "compression_ratio": result.compression_ratio,
                 "compression_ratio": result.compression_ratio,
-                "no_caption_prob": result.no_caption_prob,
+                "no_speech_prob": result.no_speech_prob,
             }
             }
         )
         )
         if verbose:
         if verbose:
@@ -163,11 +163,11 @@ def transcribe(
         result = decode_with_fallback(segment)[0]
         result = decode_with_fallback(segment)[0]
         tokens = torch.tensor(result.tokens)
         tokens = torch.tensor(result.tokens)
 
 
-        if no_captions_threshold is not None:
+        if no_speech_threshold is not None:
             # no voice activity check
             # no voice activity check
-            should_skip = result.no_caption_prob > no_captions_threshold
+            should_skip = result.no_speech_prob > no_speech_threshold
             if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
             if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
-                # don't skip if the logprob is high enough, despite the no_captions_prob
+                # don't skip if the logprob is high enough, despite the no_speech_prob
                 should_skip = False
                 should_skip = False
 
 
             if should_skip:
             if should_skip:
@@ -249,7 +249,7 @@ def cli():
     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
-    parser.add_argument("--no_caption_threshold", type=optional_float, default=0.6, help="if the probability of the <|nocaptions|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
 
 
     args = parser.parse_args().__dict__
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
     model_name: str = args.pop("model")
@@ -261,12 +261,8 @@ def cli():
         warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
         warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
         args["language"] = "en"
         args["language"] = "en"
 
 
-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
-    compression_ratio_threshold = args.pop("compression_ratio_threshold")
-    logprob_threshold = args.pop("logprob_threshold")
-    no_caption_threshold = args.pop("no_caption_threshold")
-
     temperature = args.pop("temperature")
     temperature = args.pop("temperature")
+    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
     if temperature_increment_on_fallback is not None:
     if temperature_increment_on_fallback is not None:
         temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
         temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
     else:
     else:
@@ -276,15 +272,7 @@ def cli():
     model = load_model(model_name, device=device)
     model = load_model(model_name, device=device)
 
 
     for audio_path in args.pop("audio"):
     for audio_path in args.pop("audio"):
-        result = transcribe(
-            model,
-            audio_path,
-            temperature=temperature,
-            compression_ratio_threshold=compression_ratio_threshold,
-            logprob_threshold=logprob_threshold,
-            no_captions_threshold=no_caption_threshold,
-            **args,
-        )
+        result = transcribe(model, audio_path, temperature=temperature, **args)
 
 
         audio_basename = os.path.basename(audio_path)
         audio_basename = os.path.basename(audio_path)