| 
					
				 | 
			
			
				@@ -7,8 +7,9 @@ import numpy as np 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import torch 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import tqdm 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from .decoding import DecodingOptions, DecodingResult 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from .timing import add_word_timestamps 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -27,6 +28,9 @@ def transcribe( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     no_speech_threshold: Optional[float] = 0.6, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     condition_on_previous_text: bool = True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     initial_prompt: Optional[str] = None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    word_timestamps: bool = False, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    prepend_punctuations: str = "\"\'“¿([{-", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    append_punctuations: str = "\"\'.。,,!!??::”)]}、", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     **decode_options, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 ): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     """ 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -63,6 +67,21 @@ def transcribe( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         disabling may make the text inconsistent across windows, but the model becomes less prone to 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    word_timestamps: bool 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        Extract word-level timestamps using the cross-attention pattern and dynamic time warping, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        and include the timestamps for each word in each segment. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    prepend_punctuations: str 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        If word_timestamps is True, merge these punctuation symbols with the next word 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    append_punctuations: str 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        If word_timestamps is True, merge these punctuation symbols with the previous word 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    initial_prompt: Optional[str] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        Optional text to provide as a prompt for the first window. This can be used to provide, or 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        to make it more likely to predict those word correctly. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     decode_options: dict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         Keyword arguments to construct `DecodingOptions` instances 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -90,16 +109,19 @@ def transcribe( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if verbose: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            _, probs = model.detect_language(segment) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            _, probs = model.detect_language(mel_segment) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             decode_options["language"] = max(probs, key=probs.get) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if verbose is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 print(f"Detected language: {LANGUAGES[decode_options['language']].title()}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    language = decode_options["language"] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    task = decode_options.get("task", "transcribe") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    language: str = decode_options["language"] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    task: str = decode_options.get("task", "transcribe") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if word_timestamps and task == "translate": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        warnings.warn("Word-level timestamps on translations may not be reliable.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         decode_result = None 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -145,42 +167,35 @@ def transcribe( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         initial_prompt_tokens = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def add_segment( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def new_segment( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if len(text.strip()) == 0:  # skip empty text output 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        all_segments.append( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "id": len(all_segments), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "seek": seek, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "start": start, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "end": end, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "text": text, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "tokens": text_tokens.tolist(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "temperature": result.temperature, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "avg_logprob": result.avg_logprob, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "compression_ratio": result.compression_ratio, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "no_speech_prob": result.no_speech_prob, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if verbose: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # show the progress bar when verbose is False (otherwise the transcribed text will be printed) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "id": len(all_segments), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "seek": seek, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "start": start, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "end": end, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "text": tokenizer.decode(text_tokens), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "tokens": text_tokens, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "temperature": result.temperature, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "avg_logprob": result.avg_logprob, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "compression_ratio": result.compression_ratio, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "no_speech_prob": result.no_speech_prob, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # show the progress bar when verbose is False (if True, transcribed text will be printed) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     num_frames = mel.shape[-1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    previous_seek_value = seek 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         while seek < num_frames: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            mel_segment = mel[:, seek:] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            segment_size = min(mel_segment.shape[-1], N_FRAMES) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             decode_options["prompt"] = all_tokens[prompt_reset_since:] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            result: DecodingResult = decode_with_fallback(segment) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            result: DecodingResult = decode_with_fallback(mel_segment) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             tokens = torch.tensor(result.tokens) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if no_speech_threshold is not None: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -191,29 +206,36 @@ def transcribe( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     should_skip = False 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 if should_skip: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    seek += segment.shape[-1]  # fast-forward to the next segment boundary 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    seek += segment_size  # fast-forward to the next segment boundary 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            previous_seek = seek 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            current_segments = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            current_tokens = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     consecutive = consecutive.tolist() + [len(tokens)] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 last_slice = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 for current_slice in consecutive: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     sliced_tokens = tokens[last_slice:current_slice] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    add_segment( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        start=timestamp_offset + start_timestamp_pos * time_precision, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        end=timestamp_offset + end_timestamp_pos * time_precision, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        text_tokens=sliced_tokens[1:-1], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    current_segments.append(new_segment( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        start=time_offset + start_timestamp_pos * time_precision, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        end=time_offset + end_timestamp_pos * time_precision, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        tokens=sliced_tokens, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                         result=result, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    )) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    current_tokens.append(sliced_tokens.tolist()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     last_slice = current_slice 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 if ended_with_single_timestamp: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     # single timestamp at the end means no speech after the last timestamp. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    seek += segment.shape[-1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    seek += segment_size 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     # otherwise, ignore the unfinished segment and seek to the last timestamp 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -227,23 +249,54 @@ def transcribe( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     duration = last_timestamp_pos * time_precision 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                add_segment( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    start=timestamp_offset, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    end=timestamp_offset + duration, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    text_tokens=tokens, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                current_segments.append(new_segment( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    start=time_offset, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    end=time_offset + duration, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    tokens=tokens, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     result=result, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                seek += segment.shape[-1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                all_tokens.extend(tokens.tolist()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                )) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                current_tokens.append(tokens.tolist()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                seek += segment_size 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if not condition_on_previous_text or result.temperature > 0.5: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 # do not feed the prompt tokens if a high temperature was used 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 prompt_reset_since = len(all_tokens) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if word_timestamps: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                add_word_timestamps( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    segments=current_segments, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    model=model, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    tokenizer=tokenizer, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    mel=mel_segment, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    num_frames=segment_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    prepend_punctuations=prepend_punctuations, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    append_punctuations=append_punctuations, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if len(consecutive) > 0 and len(word_end_timestamps) > 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    if seek_shift > 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        seek = previous_seek + seek_shift 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if verbose: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                for segment in current_segments: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    start, end, text = segment["start"], segment["end"], segment["text"] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    print(make_safe(line)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # if a segment is instantaneous or does not contain text, clear it 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for i, segment in enumerate(current_segments): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if segment["start"] == segment["end"] or segment["text"].strip() == "": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    segment["text"] = "" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    segment["tokens"] = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    segment["words"] = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    current_tokens[i] = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            all_segments.extend(current_segments) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            all_tokens.extend([token for segment in current_tokens for token in segment]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # update progress bar 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            pbar.update(min(num_frames, seek) - previous_seek_value) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            previous_seek_value = seek 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            pbar.update(min(num_frames, seek) - previous_seek) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return dict( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]), 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -282,6 +335,9 @@ def cli(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     args = parser.parse_args().__dict__ 
			 |