Jelajahi Sumber

transcribe() on English-only model won't complain when language="en" is not given

Jong Wook Kim 2 tahun lalu
induk
melakukan
d18e9ea5dd
1 mengubah file dengan 12 tambahan dan 8 penghapusan
  1. 12 8
      whisper/transcribe.py

+ 12 - 8
whisper/transcribe.py

@@ -84,13 +84,16 @@ def transcribe(
     mel = log_mel_spectrogram(audio)
 
     if decode_options.get("language", None) is None:
-        if verbose:
-            print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
-        segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
-        _, probs = model.detect_language(segment)
-        decode_options["language"] = max(probs, key=probs.get)
-        if verbose is not None:
-            print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
+        if not model.is_multilingual:
+            decode_options["language"] = "en"
+        else:
+            if verbose:
+                print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
+            segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+            _, probs = model.detect_language(segment)
+            decode_options["language"] = max(probs, key=probs.get)
+            if verbose is not None:
+                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
     language = decode_options["language"]
     task = decode_options.get("task", "transcribe")
@@ -282,7 +285,8 @@ def cli():
     os.makedirs(output_dir, exist_ok=True)
 
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
-        warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
+        if args["language"] is not None:
+            warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
         args["language"] = "en"
 
     temperature = args.pop("temperature")