test_transcribe.py 699 B

12345678910111213141516171819202122
  1. import os
  2. import pytest
  3. import torch
  4. import whisper
  5. @pytest.mark.parametrize("model_name", whisper.available_models())
  6. def test_transcribe(model_name: str):
  7. device = "cuda" if torch.cuda.is_available() else "cpu"
  8. model = whisper.load_model(model_name).to(device)
  9. audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
  10. language = "en" if model_name.endswith(".en") else None
  11. result = model.transcribe(audio_path, language=language, temperature=0.0)
  12. assert result["language"] == "en"
  13. transcription = result["text"].lower()
  14. assert "my fellow americans" in transcription
  15. assert "your country" in transcription
  16. assert "do for you" in transcription