瀏覽代碼

fix-space-issue

jhj0517 9 月之前
父節點
當前提交
9a97a9b254
共有 2 個文件被更改,包括 28 次插入和 6 次删除
  1. 11 6
      whisper/transcribe.py
  2. 17 0
      whisper/utils.py

+ 11 - 6
whisper/transcribe.py

@@ -10,7 +10,7 @@ import tqdm
 from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
 from .decoding import DecodingOptions, DecodingResult
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
+from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer, remove_leading_spaces
 
 if TYPE_CHECKING:
     from .model import Whisper
@@ -250,11 +250,16 @@ def transcribe(
             pbar.update(min(num_frames, seek) - previous_seek_value)
             previous_seek_value = seek
 
-    return dict(
-        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
-        segments=all_segments,
-        language=language
-    )
+            result = dict(
+                text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
+                segments=all_segments,
+                language=language
+            )
+
+            if decode_options["language"] == "ko":
+                result = remove_leading_spaces(result)
+
+    return result
 
 
 def cli():

+ 17 - 0
whisper/utils.py

@@ -161,3 +161,20 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
 
     return writers[output_format](output_dir)
 
+
+def remove_leading_spaces(
+        result: dict
+) -> dict:
+    """
+    Strip one leading space from result['text'] and from the 'text' of every entry in result['segments'].
+    The dicts are edited in place and the same ``result`` object is returned. Currently applied to Korean output.
+    """
+    def _drop_first_space(entry: dict) -> None:
+        text = entry['text']
+        if text.startswith(' '):
+            entry['text'] = text[1:]
+
+    _drop_first_space(result)
+    for seg in result['segments']:
+        _drop_first_space(seg)
+    return result