Kaynağa Gözat

handle printing even if sys.stdout.buffer is not available (#887)

Jong Wook Kim 2 yıl önce
ebeveyn
işleme
7f1ef223ab
2 değiştirilmiş dosya ile 16 ekleme ve 8 silme
  1. 3 8
      whisper/transcribe.py
  2. 13 0
      whisper/utils.py

+ 3 - 8
whisper/transcribe.py

@@ -1,8 +1,7 @@
 import argparse
 import os
-import sys
 import warnings
-from typing import List, Optional, Tuple, Union, TYPE_CHECKING
+from typing import Optional, Tuple, Union, TYPE_CHECKING
 
 import numpy as np
 import torch
@@ -11,7 +10,7 @@ import tqdm
 from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
 from .decoding import DecodingOptions, DecodingResult
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, get_writer
+from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
 
 if TYPE_CHECKING:
     from .model import Whisper
@@ -166,11 +165,7 @@ def transcribe(
             }
         )
         if verbose:
-            line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}\n"
-            # compared to just `print(line)`, this replaces any character not representable using
-            # the system default encoding with an '?', avoiding UnicodeEncodeError.
-            sys.stdout.buffer.write(line.encode(sys.getdefaultencoding(), errors="replace"))
-            sys.stdout.flush()
+            print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
 
     # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
     num_frames = mel.shape[-1]

+ 13 - 0
whisper/utils.py

@@ -1,8 +1,21 @@
 import json
 import os
+import sys
 import zlib
 from typing import Callable, TextIO
 
+system_encoding = sys.getdefaultencoding()
+
+if system_encoding != "utf-8":
+    def make_safe(string):
+        # replaces any character not representable using the system default encoding with an '?',
+        # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
+        return string.encode(system_encoding, errors="replace").decode(system_encoding)
+else:
+    def make_safe(string):
+        # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
+        return string
+
 
 def exact_div(x, y):
     assert x % y == 0