Browse Source

add progress bar for transcribe loop (#100)

* add progress bar to transcribe loop

* improved warning message for English-only models

* add --condition_on_previous_text

* progressbar renames

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
fatih 2 years ago
parent
commit
9e7e418ff1
1 changed files with 73 additions and 63 deletions
  1. 73 63
      whisper/transcribe.py

+ 73 - 63
whisper/transcribe.py

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union, TYPE_CHECKING
 
 import numpy as np
 import torch
+import tqdm
 
 from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
 from .decoding import DecodingOptions, DecodingResult
@@ -87,7 +88,7 @@ def transcribe(
         segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
         _, probs = model.detect_language(segment)
         decode_options["language"] = max(probs, key=probs.get)
-        print(f"Detected language: {LANGUAGES[decode_options['language']]}")
+        print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
     mel = mel.unsqueeze(0)
     language = decode_options["language"]
@@ -160,72 +161,81 @@ def transcribe(
         if verbose:
             print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
 
-    while seek < mel.shape[-1]:
-        timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-        segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
-        segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
-
-        decode_options["prompt"] = all_tokens[prompt_reset_since:]
-        result = decode_with_fallback(segment)[0]
-        tokens = torch.tensor(result.tokens)
-
-        if no_speech_threshold is not None:
-            # no voice activity check
-            should_skip = result.no_speech_prob > no_speech_threshold
-            if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
-                # don't skip if the logprob is high enough, despite the no_speech_prob
-                should_skip = False
-
-            if should_skip:
-                seek += segment.shape[-1]  # fast-forward to the next segment boundary
-                continue
-
-        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
-        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
-            last_slice = 0
-            for current_slice in consecutive:
-                sliced_tokens = tokens[last_slice:current_slice]
-                start_timestamp_position = (
-                    sliced_tokens[0].item() - tokenizer.timestamp_begin
-                )
-                end_timestamp_position = (
-                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
+    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
+    num_frames = mel.shape[-1]
+    previous_seek_value = seek
+
+    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose) as pbar:
+        while seek < num_frames:
+            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
+            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+
+            decode_options["prompt"] = all_tokens[prompt_reset_since:]
+            result = decode_with_fallback(segment)[0]
+            tokens = torch.tensor(result.tokens)
+
+            if no_speech_threshold is not None:
+                # no voice activity check
+                should_skip = result.no_speech_prob > no_speech_threshold
+                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                    # don't skip if the logprob is high enough, despite the no_speech_prob
+                    should_skip = False
+
+                if should_skip:
+                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
+                    continue
+
+            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
+            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+                last_slice = 0
+                for current_slice in consecutive:
+                    sliced_tokens = tokens[last_slice:current_slice]
+                    start_timestamp_position = (
+                        sliced_tokens[0].item() - tokenizer.timestamp_begin
+                    )
+                    end_timestamp_position = (
+                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                    )
+                    add_segment(
+                        start=timestamp_offset + start_timestamp_position * time_precision,
+                        end=timestamp_offset + end_timestamp_position * time_precision,
+                        text_tokens=sliced_tokens[1:-1],
+                        result=result,
+                    )
+                    last_slice = current_slice
+                last_timestamp_position = (
+                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                 )
+                seek += last_timestamp_position * input_stride
+                all_tokens.extend(tokens[: last_slice + 1].tolist())
+            else:
+                duration = segment_duration
+                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+                if len(timestamps) > 0:
+                    # no consecutive timestamps but it has a timestamp; use the last one.
+                    # single timestamp at the end means no speech after the last timestamp.
+                    last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
+                    duration = last_timestamp_position * time_precision
+
                 add_segment(
-                    start=timestamp_offset + start_timestamp_position * time_precision,
-                    end=timestamp_offset + end_timestamp_position * time_precision,
-                    text_tokens=sliced_tokens[1:-1],
+                    start=timestamp_offset,
+                    end=timestamp_offset + duration,
+                    text_tokens=tokens,
                     result=result,
                 )
-                last_slice = current_slice
-            last_timestamp_position = (
-                tokens[last_slice - 1].item() - tokenizer.timestamp_begin
-            )
-            seek += last_timestamp_position * input_stride
-            all_tokens.extend(tokens[: last_slice + 1].tolist())
-        else:
-            duration = segment_duration
-            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-            if len(timestamps) > 0:
-                # no consecutive timestamps but it has a timestamp; use the last one.
-                # single timestamp at the end means no speech after the last timestamp.
-                last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
-                duration = last_timestamp_position * time_precision
-
-            add_segment(
-                start=timestamp_offset,
-                end=timestamp_offset + duration,
-                text_tokens=tokens,
-                result=result,
-            )
-
-            seek += segment.shape[-1]
-            all_tokens.extend(tokens.tolist())
-
-        if not condition_on_previous_text or result.temperature > 0.5:
-            # do not feed the prompt tokens if a high temperature was used
-            prompt_reset_since = len(all_tokens)
+
+                seek += segment.shape[-1]
+                all_tokens.extend(tokens.tolist())
+
+            if not condition_on_previous_text or result.temperature > 0.5:
+                # do not feed the prompt tokens if a high temperature was used
+                prompt_reset_since = len(all_tokens)
+
+            # update progress bar
+            pbar.update(min(num_frames, seek) - previous_seek_value)
+            previous_seek_value = seek
 
     return dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)