|
@@ -26,6 +26,7 @@ def transcribe(
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
+    initial_prompt: Optional[str] = None,
     **decode_options,
 ):
     """
@@ -138,10 +139,11 @@ def transcribe(
     all_segments = []
     prompt_reset_since = 0
 
-    initial_prompt = decode_options.pop("initial_prompt", None) or []
-    if initial_prompt:
-        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
-        all_tokens.extend(initial_prompt)
+    if initial_prompt is not None:
+        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+        all_tokens.extend(initial_prompt_tokens)
+    else:
+        initial_prompt_tokens = []
 
     def add_segment(
         *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
@@ -243,7 +245,11 @@ def transcribe(
             pbar.update(min(num_frames, seek) - previous_seek_value)
             previous_seek_value = seek
 
-    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
+    return dict(
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
+        segments=all_segments,
+        language=language
+    )
 
 
 def cli():
@@ -292,21 +298,18 @@ def cli():
         args["language"] = "en"
 
     temperature = args.pop("temperature")
-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
-    if temperature_increment_on_fallback is not None:
-        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
+    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
+        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
     else:
         temperature = [temperature]
 
-    threads = args.pop("threads")
-    if threads > 0:
+    if (threads := args.pop("threads")) > 0:
         torch.set_num_threads(threads)
 
     from . import load_model
     model = load_model(model_name, device=device, download_root=model_dir)
 
     writer = get_writer(output_format, output_dir)
-
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
         writer(result, audio_path)