test_transcribe.py

import os

import pytest
import torch

import whisper


@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    language = "en" if model_name.endswith(".en") else None
    result = model.transcribe(
        audio_path, language=language, temperature=0.0, word_timestamps=True
    )

    # The detected language and the concatenation of segment texts must match
    # the top-level transcript.
    assert result["language"] == "en"
    assert result["text"] == "".join([s["text"] for s in result["segments"]])

    # Spot-check well-known phrases from the JFK inaugural excerpt.
    transcription = result["text"].lower()
    assert "my fellow americans" in transcription
    assert "your country" in transcription
    assert "do for you" in transcription

    # Verify per-word timestamps: every word spans a positive interval, and
    # the word "Americans" straddles the 1.8-second mark.
    timing_checked = False
    for segment in result["segments"]:
        for timing in segment["words"]:
            assert timing["start"] < timing["end"]
            if timing["word"].strip(" ,") == "Americans":
                assert timing["start"] <= 1.8
                assert timing["end"] >= 1.8
                print(timing)
                timing_checked = True

    assert timing_checked
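
For reference, the word-level timing structure that the assertions above rely on can be exercised outside pytest. The snippet below is a minimal sketch, not part of the test file; the model name "tiny.en" and the bare "jfk.flac" path are illustrative assumptions.

# Minimal sketch of the transcribe call the test exercises.
# Assumed model name and audio path are placeholders for illustration.
import whisper

model = whisper.load_model("tiny.en")
result = model.transcribe("jfk.flac", temperature=0.0, word_timestamps=True)

# With word_timestamps=True, each segment carries a "words" list whose
# entries expose "word", "start", and "end" fields, as asserted in the test.
for segment in result["segments"]:
    for timing in segment["words"]:
        print(timing["word"], timing["start"], timing["end"])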