| 
					
				 | 
			
			
				@@ -26,6 +26,7 @@ def transcribe( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     logprob_threshold: Optional[float] = -1.0, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     no_speech_threshold: Optional[float] = 0.6, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     condition_on_previous_text: bool = True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    initial_prompt: Optional[str] = None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     **decode_options, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 ): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     """ 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -138,10 +139,11 @@ def transcribe( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     all_segments = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     prompt_reset_since = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    initial_prompt = decode_options.pop("initial_prompt", None) or [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    if initial_prompt: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        initial_prompt = tokenizer.encode(" " + initial_prompt.strip()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        all_tokens.extend(initial_prompt) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if initial_prompt is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        all_tokens.extend(initial_prompt_tokens) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        initial_prompt_tokens = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def add_segment( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -243,7 +245,11 @@ def transcribe( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             pbar.update(min(num_frames, seek) - previous_seek_value) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             previous_seek_value = seek 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return dict( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        segments=all_segments, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        language=language 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def cli(): 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -292,21 +298,18 @@ def cli(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         args["language"] = "en" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     temperature = args.pop("temperature") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    if temperature_increment_on_fallback is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if (increment := args.pop("temperature_increment_on_fallback")) is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         temperature = [temperature] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    threads = args.pop("threads") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    if threads > 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if (threads := args.pop("threads")) > 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         torch.set_num_threads(threads) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     from . import load_model 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     model = load_model(model_name, device=device, download_root=model_dir) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     writer = get_writer(output_format, output_dir) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     for audio_path in args.pop("audio"): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         result = transcribe(model, audio_path, temperature=temperature, **args) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         writer(result, audio_path) 
			 |