# test_transcribe.py

import os

import pytest
import torch
import whisper
  4. @pytest.mark.parametrize('model_name', whisper.available_models())
  5. def test_transcribe(model_name: str):
  6. model = whisper.load_model(model_name).cuda()
  7. audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
  8. language = "en" if model_name.endswith(".en") else None
  9. result = model.transcribe(audio_path, language=language, temperature=0.0)
  10. assert result["language"] == "en"
  11. transcription = result["text"].lower()
  12. assert "my fellow americans" in transcription
  13. assert "your country" in transcription
  14. assert "do for you" in transcription