# test_transcribe.py

# (scrape artifact removed: concatenated source line numbers)
  1. import os
  2. import pytest
  3. import torch
  4. import whisper
  5. from whisper.tokenizer import get_tokenizer
  6. @pytest.mark.parametrize("model_name", whisper.available_models())
  7. def test_transcribe(model_name: str):
  8. device = "cuda" if torch.cuda.is_available() else "cpu"
  9. model = whisper.load_model(model_name).to(device)
  10. audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
  11. language = "en" if model_name.endswith(".en") else None
  12. result = model.transcribe(
  13. audio_path, language=language, temperature=0.0, word_timestamps=True
  14. )
  15. assert result["language"] == "en"
  16. assert result["text"] == "".join([s["text"] for s in result["segments"]])
  17. transcription = result["text"].lower()
  18. assert "my fellow americans" in transcription
  19. assert "your country" in transcription
  20. assert "do for you" in transcription
  21. tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
  22. all_tokens = [t for s in result["segments"] for t in s["tokens"]]
  23. assert tokenizer.decode(all_tokens) == result["text"]
  24. assert tokenizer.decode_with_timestamps(all_tokens).startswith("<|0.00|>")
  25. timing_checked = False
  26. for segment in result["segments"]:
  27. for timing in segment["words"]:
  28. assert timing["start"] < timing["end"]
  29. if timing["word"].strip(" ,") == "Americans":
  30. assert timing["start"] <= 1.8
  31. assert timing["end"] >= 1.8
  32. timing_checked = True
  33. assert timing_checked