Browse Source

fix: transcribe verbosity (#140)

Nick Konovalchuk 2 years ago
parent
commit
b4308c4782
1 changed files with 6 additions and 4 deletions
  1. 6 4
      whisper/transcribe.py

+ 6 - 4
whisper/transcribe.py

@@ -20,7 +20,7 @@ def transcribe(
     model: "Whisper",
     audio: Union[str, np.ndarray, torch.Tensor],
     *,
-    verbose: bool = False,
+    verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
     logprob_threshold: Optional[float] = -1.0,
@@ -40,7 +40,8 @@ def transcribe(
         The path to the audio file to open, or the audio waveform
 
     verbose: bool
-        Whether to display the text being decoded to the console
+        Whether to display the text being decoded to the console. If True, displays all the details,
+        If False, displays minimal details. If None, does not display anything
 
     temperature: Union[float, Tuple[float, ...]]
         Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
@@ -88,7 +89,8 @@ def transcribe(
         segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
         _, probs = model.detect_language(segment)
         decode_options["language"] = max(probs, key=probs.get)
-        print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
+        if verbose is not None:
+            print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
     mel = mel.unsqueeze(0)
     language = decode_options["language"]
@@ -170,7 +172,7 @@ def transcribe(
     num_frames = mel.shape[-1]
     previous_seek_value = seek
 
-    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose) as pbar:
+    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
             timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
             segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)