瀏覽代碼

allowing nonzero initial temperature

Jong Wook Kim 2 年之前
父節點
當前提交
7cb4cc21bf
共有 2 個文件被更改,包括 29 次插入33 次删除
  1. 1 1
      whisper/decoding.py
  2. 28 32
      whisper/transcribe.py

+ 1 - 1
whisper/decoding.py

@@ -94,7 +94,7 @@ class DecodingOptions:
 
     # timestamp sampling options
     without_timestamps: bool = False              # use <|notimestamps|> to sample text tokens only
-    max_initial_timestamp: Optional[float] = 0.0  # the initial timestamp cannot be later than this
+    max_initial_timestamp: Optional[float] = 1.0  # the initial timestamp cannot be later than this
 
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation

+ 28 - 32
whisper/transcribe.py

@@ -92,41 +92,37 @@ def transcribe(
         if verbose is not None:
             print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
-    mel = mel.unsqueeze(0)
     language = decode_options["language"]
     task = decode_options.get("task", "transcribe")
     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
 
-    def decode_with_fallback(segment: torch.Tensor) -> List[DecodingResult]:
+    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
         temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
-        kwargs = {**decode_options}
-        t = temperatures[0]
-        if t == 0:
-            best_of = kwargs.pop("best_of", None)
-        else:
-            best_of = kwargs.get("best_of", None)
-
-        options = DecodingOptions(**kwargs, temperature=t)
-        results = model.decode(segment, options)
-
-        kwargs.pop("beam_size", None)  # no beam search for t > 0
-        kwargs.pop("patience", None)  # no patience for t > 0
-        kwargs["best_of"] = best_of  # enable best_of for t > 0
-        for t in temperatures[1:]:
-            needs_fallback = [
-                compression_ratio_threshold is not None
-                and result.compression_ratio > compression_ratio_threshold
-                or logprob_threshold is not None
-                and result.avg_logprob < logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                options = DecodingOptions(**kwargs, temperature=t)
-                retries = model.decode(segment[needs_fallback], options)
-                for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
-                    results[original_index] = retries[retry_index]
-
-        return results
+        decode_result = None
+
+        for t in temperatures:
+            kwargs = {**decode_options}
+            if t > 0:
+                # disable beam_size and patience when t > 0
+                kwargs.pop("beam_size", None)
+                kwargs.pop("patience", None)
+            else:
+                # disable best_of when t == 0
+                kwargs.pop("best_of", None)
+
+            options = DecodingOptions(**kwargs, temperature=t)
+            decode_result = model.decode(segment, options)
+
+            needs_fallback = False
+            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+                needs_fallback = True  # too repetitive
+            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+                needs_fallback = True  # average log probability is too low
+
+            if not needs_fallback:
+                break
+
+        return decode_result
 
     seek = 0
     input_stride = exact_div(
@@ -175,11 +171,11 @@ def transcribe(
     with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
             timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
+            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
             segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
 
             decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result = decode_with_fallback(segment)[0]
+            result: DecodingResult = decode_with_fallback(segment)
             tokens = torch.tensor(result.tokens)
 
             if no_speech_threshold is not None: