| 
					
				 | 
			
			
				@@ -1,5 +1,7 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import os 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import zlib 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from typing import Iterator, TextIO 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from typing import Callable, TextIO 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def exact_div(x, y): 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -45,44 +47,83 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def write_txt(transcript: Iterator[dict], file: TextIO): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    for segment in transcript: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print(segment['text'].strip(), file=file, flush=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def write_vtt(transcript: Iterator[dict], file: TextIO): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    print("WEBVTT\n", file=file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    for segment in transcript: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            f"{segment['text'].strip().replace('-->', '->')}\n", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            file=file, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            flush=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def write_srt(transcript: Iterator[dict], file: TextIO): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    Write a transcript to a file in SRT format. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    Example usage: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        from pathlib import Path 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        from whisper.utils import write_srt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        result = transcribe(model, audio_path, temperature=temperature, **args) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # save SRT 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        audio_basename = Path(audio_path).stem 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            write_srt(result["segments"], file=srt) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    for i, segment in enumerate(transcript, start=1): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # write srt lines 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            f"{i}\n" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            f"{segment['text'].strip().replace('-->', '->')}\n", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            file=file, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            flush=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class ResultWriter: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    extension: str 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __init__(self, output_dir: str): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.output_dir = output_dir 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __call__(self, result: dict, audio_path: str): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        audio_basename = os.path.basename(audio_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        with open(output_path, "w", encoding="utf-8") as f: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.write_result(result, file=f) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def write_result(self, result: dict, file: TextIO): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        raise NotImplementedError 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class WriteTXT(ResultWriter): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    extension: str = "txt" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def write_result(self, result: dict, file: TextIO): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for segment in result["segments"]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            print(segment['text'].strip(), file=file, flush=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class WriteVTT(ResultWriter): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    extension: str = "vtt" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def write_result(self, result: dict, file: TextIO): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print("WEBVTT\n", file=file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for segment in result["segments"]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            print( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                f"{segment['text'].strip().replace('-->', '->')}\n", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                file=file, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                flush=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class WriteSRT(ResultWriter): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    extension: str = "srt" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def write_result(self, result: dict, file: TextIO): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for i, segment in enumerate(result["segments"], start=1): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # write srt lines 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            print( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                f"{i}\n" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                f"{segment['text'].strip().replace('-->', '->')}\n", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                file=file, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                flush=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class WriteJSON(ResultWriter): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    extension: str = "json" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def write_result(self, result: dict, file: TextIO): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        json.dump(result, file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    writers = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        "txt": WriteTXT, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        "vtt": WriteVTT, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        "srt": WriteSRT, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        "json": WriteJSON, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if output_format == "all": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        all_writers = [writer(output_dir) for writer in writers.values()] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        def write_all(result: dict, file: TextIO): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for writer in all_writers: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                writer(result, file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return write_all 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return writers[output_format](output_dir) 
			 |