test_transcribe.py

import os

import pytest
import torch

import whisper


@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    # English-only checkpoints ("*.en") get an explicit language hint;
    # multilingual checkpoints are left to auto-detect the language.
    language = "en" if model_name.endswith(".en") else None
    result = model.transcribe(
        audio_path, language=language, temperature=0.0, word_timestamps=True
    )
    assert result["language"] == "en"

    transcription = result["text"].lower()
    assert "my fellow americans" in transcription
    assert "your country" in transcription
    assert "do for you" in transcription

    # Check word-level timestamps: every word span must be non-empty, and the
    # word "Americans" should straddle the 1.8 s mark in the JFK clip.
    timing_checked = False
    for segment in result["segments"]:
        for timing in segment["words"]:
            assert timing["start"] < timing["end"]
            if timing["word"].strip(" ,") == "Americans":
                assert timing["start"] <= 1.8
                assert timing["end"] >= 1.8
                print(timing)
                timing_checked = True

    assert timing_checked
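
For reference, a minimal standalone sketch of the same transcription call the test exercises, outside of pytest. The "tiny" model and the bare "jfk.flac" path are illustrative choices, not fixed by the test.

# Standalone sketch; model name and audio path are assumptions for illustration.
import whisper

model = whisper.load_model("tiny")
result = model.transcribe(
    "jfk.flac", language="en", temperature=0.0, word_timestamps=True
)
print(result["text"])
# With word_timestamps=True, each segment carries per-word start/end times.
for segment in result["segments"]:
    for word in segment["words"]:
        print(f'{word["word"]!r}: {word["start"]:.2f}s -> {word["end"]:.2f}s')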