@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union, TYPE_CHECKING
 
 import numpy as np
 import torch
+import tqdm
 
 from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
 from .decoding import DecodingOptions, DecodingResult
@@ -87,7 +88,7 @@ def transcribe(
         segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
         _, probs = model.detect_language(segment)
         decode_options["language"] = max(probs, key=probs.get)
-        print(f"Detected language: {LANGUAGES[decode_options['language']]}")
+        print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
     mel = mel.unsqueeze(0)
     language = decode_options["language"]
@@ -160,72 +161,81 @@ def transcribe(
         if verbose:
             print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
 
-    while seek < mel.shape[-1]:
-        timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-        segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
-        segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
-
-        decode_options["prompt"] = all_tokens[prompt_reset_since:]
-        result = decode_with_fallback(segment)[0]
-        tokens = torch.tensor(result.tokens)
-
-        if no_speech_threshold is not None:
-            # no voice activity check
-            should_skip = result.no_speech_prob > no_speech_threshold
-            if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
-                # don't skip if the logprob is high enough, despite the no_speech_prob
-                should_skip = False
-
-            if should_skip:
-                seek += segment.shape[-1]  # fast-forward to the next segment boundary
-                continue
-
-        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
-        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
-            last_slice = 0
-            for current_slice in consecutive:
-                sliced_tokens = tokens[last_slice:current_slice]
-                start_timestamp_position = (
-                    sliced_tokens[0].item() - tokenizer.timestamp_begin
-                )
-                end_timestamp_position = (
-                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
+    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
+    num_frames = mel.shape[-1]
+    previous_seek_value = seek
+
+    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose) as pbar:
+        while seek < num_frames:
+            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
+            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+
+            decode_options["prompt"] = all_tokens[prompt_reset_since:]
+            result = decode_with_fallback(segment)[0]
+            tokens = torch.tensor(result.tokens)
+
+            if no_speech_threshold is not None:
+                # no voice activity check
+                should_skip = result.no_speech_prob > no_speech_threshold
+                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                    # don't skip if the logprob is high enough, despite the no_speech_prob
+                    should_skip = False
+
+                if should_skip:
+                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
+                    continue
+
+            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
+            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+                last_slice = 0
+                for current_slice in consecutive:
+                    sliced_tokens = tokens[last_slice:current_slice]
+                    start_timestamp_position = (
+                        sliced_tokens[0].item() - tokenizer.timestamp_begin
+                    )
+                    end_timestamp_position = (
+                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                    )
+                    add_segment(
+                        start=timestamp_offset + start_timestamp_position * time_precision,
+                        end=timestamp_offset + end_timestamp_position * time_precision,
+                        text_tokens=sliced_tokens[1:-1],
+                        result=result,
+                    )
+                    last_slice = current_slice
+                last_timestamp_position = (
+                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                 )
+                seek += last_timestamp_position * input_stride
+                all_tokens.extend(tokens[: last_slice + 1].tolist())
+            else:
+                duration = segment_duration
+                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+                if len(timestamps) > 0:
+                    # no consecutive timestamps but it has a timestamp; use the last one.
+                    # single timestamp at the end means no speech after the last timestamp.
+                    last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
+                    duration = last_timestamp_position * time_precision
+
                 add_segment(
-                    start=timestamp_offset + start_timestamp_position * time_precision,
-                    end=timestamp_offset + end_timestamp_position * time_precision,
-                    text_tokens=sliced_tokens[1:-1],
+                    start=timestamp_offset,
+                    end=timestamp_offset + duration,
+                    text_tokens=tokens,
                     result=result,
                 )
-                last_slice = current_slice
-            last_timestamp_position = (
-                tokens[last_slice - 1].item() - tokenizer.timestamp_begin
-            )
-            seek += last_timestamp_position * input_stride
-            all_tokens.extend(tokens[: last_slice + 1].tolist())
-        else:
-            duration = segment_duration
-            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-            if len(timestamps) > 0:
-                # no consecutive timestamps but it has a timestamp; use the last one.
-                # single timestamp at the end means no speech after the last timestamp.
-                last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
-                duration = last_timestamp_position * time_precision
-
-            add_segment(
-                start=timestamp_offset,
-                end=timestamp_offset + duration,
-                text_tokens=tokens,
-                result=result,
-            )
-
-            seek += segment.shape[-1]
-            all_tokens.extend(tokens.tolist())
-
-        if not condition_on_previous_text or result.temperature > 0.5:
-            # do not feed the prompt tokens if a high temperature was used
-            prompt_reset_since = len(all_tokens)
+
+                seek += segment.shape[-1]
+                all_tokens.extend(tokens.tolist())
+
+            if not condition_on_previous_text or result.temperature > 0.5:
+                # do not feed the prompt tokens if a high temperature was used
+                prompt_reset_since = len(all_tokens)
+
+            # update progress bar
+            pbar.update(min(num_frames, seek) - previous_seek_value)
+            previous_seek_value = seek
 
     return dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
 
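For reference, the progress-bar pattern this diff introduces can be exercised on its own. The sketch below is a minimal, self-contained illustration, not part of the patch: the `process` function, the sleep, and the frame counts are hypothetical stand-ins, while the `disable=verbose` flag and the `min(num_frames, seek)` clamp mirror the diff above.

import time

import tqdm

def process(num_frames: int, verbose: bool = False) -> None:
    """Minimal sketch of the tqdm pattern in this diff (hypothetical loop body)."""
    seek = 0
    previous_seek_value = seek

    # disable=verbose: hide the bar when per-segment text is being printed instead
    with tqdm.tqdm(total=num_frames, unit="frames", disable=verbose) as pbar:
        while seek < num_frames:
            time.sleep(0.01)  # stand-in for decoding one window of audio
            seek += 3000      # a full window can overshoot num_frames on the last pass

            # clamp with min() so the bar never advances past its total
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

process(num_frames=10000)

The clamp matters because the loop advances `seek` a full decoding window at a time, so the final iteration can step past the total; tracking `previous_seek_value` lets each `pbar.update` report only the delta since the last iteration.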