@@ -11,6 +11,7 @@ from .audio import (
     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
+    N_SAMPLES,
     SAMPLE_RATE,
     log_mel_spectrogram,
     pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False

-    mel = log_mel_spectrogram(audio)
+    # Pad 30-seconds of silence to the input audio, for slicing
+    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    content_frames = mel.shape[-1] - N_FRAMES

     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
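This hunk changes the padding strategy: the log-Mel spectrogram is now computed over the input audio plus 30 seconds of appended silence, and content_frames counts only the frames that belong to the real audio, so the decoding loop can slice a full-width window anywhere without running off the end. A minimal sketch of the frame arithmetic (not part of the patch), assuming the constants from whisper/audio.py, where SAMPLE_RATE = 16000, HOP_LENGTH = 160, and CHUNK_LENGTH = 30:

    SAMPLE_RATE = 16000
    HOP_LENGTH = 160
    CHUNK_LENGTH = 30
    N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE   # 480000 samples per 30-second window
    N_FRAMES = N_SAMPLES // HOP_LENGTH       # 3000 mel frames per window

    def content_frames_for(audio_samples: int) -> int:
        # Frames covering the real audio once 30 s of silence is appended:
        # log_mel_spectrogram yields one frame per HOP_LENGTH samples.
        total_frames = (audio_samples + N_SAMPLES) // HOP_LENGTH
        return total_frames - N_FRAMES       # == audio_samples // HOP_LENGTH

    # A 45-second clip: mel.shape[-1] would be 7500 frames including the
    # padding, of which 4500 are real content.
    assert content_frames_for(45 * SAMPLE_RATE) == 4500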
@@ -212,14 +215,13 @@ def transcribe(
         }

     # show the progress bar when verbose is False (if True, transcribed text will be printed)
-    num_frames = mel.shape[-1]
     with tqdm.tqdm(
-        total=num_frames, unit="frames", disable=verbose is not False
+        total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
-        while seek < num_frames:
+        while seek < content_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[:, seek:]
-            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            mel_segment = mel[:, seek : seek + N_FRAMES]
+            segment_size = min(N_FRAMES, content_frames - seek)
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
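With the mel padded by a full window of silence, mel[:, seek : seek + N_FRAMES] always yields an N_FRAMES-wide segment, so pad_or_trim no longer has to zero-fill a short final slice; segment_size still measures how much of the window is real content. A toy illustration under the same assumed constants (shapes only, no model involved):

    import torch

    N_FRAMES, n_mels = 3000, 80               # Whisper's defaults
    content_frames = 4500                     # the 45-second clip from above
    mel = torch.zeros(n_mels, content_frames + N_FRAMES)  # audio + 30 s silence

    seek = 3000                               # second pass through the loop
    mel_segment = mel[:, seek : seek + N_FRAMES]
    segment_size = min(N_FRAMES, content_frames - seek)
    assert mel_segment.shape[-1] == N_FRAMES  # full window, thanks to the padding
    assert segment_size == 1500               # only 15 s of it is real audio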
@@ -246,20 +248,18 @@ def transcribe(
             current_tokens = []

             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
-                0
-            ].add_(1)
-            if (
-                len(consecutive) > 0
-            ):  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
-                    False,
-                    True,
-                ]:
-                    consecutive = consecutive.tolist() + [len(tokens)]
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+            consecutive.add_(1)
+            if len(consecutive) > 0:
+                # if the output contains two consecutive timestamp tokens
+                slices = consecutive.tolist()
+                if single_timestamp_ending:
+                    slices.append(len(tokens))

                 last_slice = 0
-                for current_slice in consecutive:
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = (
                         sliced_tokens[0].item() - tokenizer.timestamp_begin
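The rewrite hoists the tail check into single_timestamp_ending, computed unconditionally so the later hunks can reuse it, and collects the segment boundaries into an explicit slices list instead of letting consecutive flip between tensor and list. A toy run of the pairing logic (hypothetical token pattern, not from the patch):

    import torch

    # True where the decoded token is a timestamp token, e.g. for a sequence
    # like: <|0.00|> hello world <|1.50|><|1.50|> again <|2.00|>
    timestamp_tokens = torch.tensor([True, False, False, True, True, False, True])

    single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]  # True

    # Indices of adjacent timestamp pairs; +1 points at the second token of
    # each pair, i.e. where the next slice of tokens begins.
    consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
    consecutive.add_(1)
    slices = consecutive.tolist()             # [4]
    if single_timestamp_ending:
        slices.append(len(timestamp_tokens))  # [4, 7]: also cut at the end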
@@ -278,7 +278,7 @@ def transcribe(
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice

-                if ended_with_single_timestamp:
+                if single_timestamp_ending:
                     # single timestamp at the end means no speech after the last timestamp.
                     seek += segment_size
                 else:
@@ -329,7 +329,7 @@ def transcribe(
                 word_end_timestamps = [
                     w["end"] for s in current_segments for w in s["words"]
                 ]
-                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                     seek_shift = round(
                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                     )
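The guard is retargeted here: the seek point is moved to the end of the last aligned word except when the segment ended with a lone timestamp, which (per the comment in the earlier hunk) already means there is no speech after it and the full-segment advance is correct. A worked example of the shift, assuming FRAMES_PER_SECOND = 100 (SAMPLE_RATE // HOP_LENGTH in whisper/audio.py):

    FRAMES_PER_SECOND = 100

    time_offset = 30.0    # this window starts 30 s into the audio
    last_word_end = 55.2  # end of the last aligned word, in seconds
    seek_shift = round((last_word_end - time_offset) * FRAMES_PER_SECOND)
    assert seek_shift == 2520  # advance 2520 frames, not the full 3000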
@@ -356,7 +356,7 @@ def transcribe(
             )

             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek)
+            pbar.update(min(content_frames, seek) - previous_seek)

     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),