|
@@ -1,7 +1,7 @@
|
|
|
import argparse
|
|
|
import os
|
|
|
import warnings
|
|
|
-from typing import Optional, Tuple, Union, TYPE_CHECKING
|
|
|
+from typing import Optional, Tuple, Union, Callable, TYPE_CHECKING
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
@@ -27,6 +27,7 @@ def transcribe(
|
|
|
no_speech_threshold: Optional[float] = 0.6,
|
|
|
condition_on_previous_text: bool = True,
|
|
|
initial_prompt: Optional[str] = None,
|
|
|
+ progress_callback: Optional[Callable[[float],None]] = None,
|
|
|
**decode_options,
|
|
|
):
|
|
|
"""
|
|
@@ -175,6 +176,10 @@ def transcribe(
|
|
|
|
|
|
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
|
|
while seek < num_frames:
|
|
|
+ if progress_callback is not None:
|
|
|
+ progress_value = seek/num_frames
|
|
|
+ progress_callback(progress_value)
|
|
|
+
|
|
|
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
|
|
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
|
|
|
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|