test_transcribe.py

import os

import pytest
import torch

import whisper


@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    # English-only checkpoints (e.g. "base.en") are given the language explicitly;
    # multilingual models are left to detect it themselves.
    language = "en" if model_name.endswith(".en") else None
    result = model.transcribe(
        audio_path, language=language, temperature=0.0, word_timestamps=True
    )
    assert result["language"] == "en"

    # The JFK clip should contain these well-known phrases.
    transcription = result["text"].lower()
    assert "my fellow americans" in transcription
    assert "your country" in transcription
    assert "do for you" in transcription

    # Word-level timestamps: every word must have a positive duration, and the
    # word "Americans" should straddle the 1.8-second mark of the recording.
    timing_checked = False
    for segment in result["segments"]:
        for timing in segment["words"]:
            assert timing["start"] < timing["end"]
            if timing["word"].strip(" ,") == "Americans":
                assert timing["start"] <= 1.8
                assert timing["end"] >= 1.8
                print(timing)
                timing_checked = True

    assert timing_checked
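
The test can be run directly with `pytest test_transcribe.py`. For context, a minimal standalone sketch of the same API the test exercises; "audio.wav" and the "base" checkpoint are illustrative choices, while the result keys ("text", "segments", "words", "word", "start", "end") are the ones the test above relies on.

import whisper

# Load a checkpoint and transcribe with word-level timestamps enabled,
# mirroring the call made in the test above. "audio.wav" is a placeholder path.
model = whisper.load_model("base")
result = model.transcribe("audio.wav", word_timestamps=True)

# Full transcription text, then one line per word with its start/end times.
print(result["text"])
for segment in result["segments"]:
    for word in segment["words"]:
        print(f'{word["start"]:6.2f} -> {word["end"]:6.2f}  {word["word"]}')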