import os

import pytest
import torch

import whisper


@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    # Pin the language for English-only (".en") models; multilingual models auto-detect it.
    language = "en" if model_name.endswith(".en") else None
    result = model.transcribe(audio_path, language=language, temperature=0.0)
    assert result["language"] == "en"

    # The JFK clip should contain these phrases regardless of model size.
    transcription = result["text"].lower()
    assert "my fellow americans" in transcription
    assert "your country" in transcription
    assert "do for you" in transcription