Browse Source

added progress callback

Jhjoon05 2 years ago
parent
commit
47c5120305
1 changed files with 6 additions and 1 deletions
  1. 6 1
      whisper/transcribe.py

+ 6 - 1
whisper/transcribe.py

@@ -1,7 +1,7 @@
 import argparse
 import os
 import warnings
-from typing import Optional, Tuple, Union, TYPE_CHECKING
+from typing import Optional, Tuple, Union, Callable, TYPE_CHECKING
 
 import numpy as np
 import torch
@@ -27,6 +27,7 @@ def transcribe(
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
+    progress_callback: Optional[Callable[[float],None]] = None,
     **decode_options,
 ):
     """
@@ -175,6 +176,10 @@ def transcribe(
 
     with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
+            if progress_callback is not None:
+                progress_value = seek/num_frames
+                progress_callback(progress_value)
+
             timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
             segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
             segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE