|
@@ -1,13 +1,15 @@
|
|
|
import os
|
|
|
|
|
|
import pytest
|
|
|
+import torch
|
|
|
|
|
|
import whisper
|
|
|
|
|
|
|
|
|
-@pytest.mark.parametrize('model_name', whisper.available_models())
|
|
|
+@pytest.mark.parametrize("model_name", whisper.available_models())
|
|
|
def test_transcribe(model_name: str):
|
|
|
- model = whisper.load_model(model_name).cuda()
|
|
|
+ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
+ model = whisper.load_model(model_name).to(device)
|
|
|
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
|
|
|
|
|
|
language = "en" if model_name.endswith(".en") else None
|