test_tokenizer.py

import pytest

from whisper.tokenizer import get_tokenizer


@pytest.mark.parametrize("multilingual", [True, False])
def test_tokenizer(multilingual):
    # Basic vocabulary invariants: the start-of-transcript token is part of the
    # SOT sequence, each language code has a language token, and all language
    # tokens precede the timestamp tokens.
    tokenizer = get_tokenizer(multilingual=multilingual)
    assert tokenizer.sot in tokenizer.sot_sequence
    assert len(tokenizer.all_language_codes) == len(tokenizer.all_language_tokens)
    assert all(c < tokenizer.timestamp_begin for c in tokenizer.all_language_tokens)


def test_multilingual_tokenizer():
    gpt2_tokenizer = get_tokenizer(multilingual=False)
    multilingual_tokenizer = get_tokenizer(multilingual=True)

    # A Korean pangram (roughly: "a squirrel wants to ride the old wheel").
    text = "다람쥐 헌 쳇바퀴에 타고파"
    gpt2_tokens = gpt2_tokenizer.encode(text)
    multilingual_tokens = multilingual_tokenizer.encode(text)

    # Both vocabularies should round-trip the text, but the multilingual one
    # should encode it more compactly than the English-only (GPT-2) vocabulary.
    assert gpt2_tokenizer.decode(gpt2_tokens) == text
    assert multilingual_tokenizer.decode(multilingual_tokens) == text
    assert len(gpt2_tokens) > len(multilingual_tokens)


def test_split_on_unicode():
    multilingual_tokenizer = get_tokenizer(multilingual=True)

    tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
    words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)

    assert words == [" elle", " est", " l", "'", "\ufffd", "é", "rit", "oire"]
    assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
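
# Note on the expected output above: "\ufffd" is the Unicode replacement
# character. The expected words keep token 246 as its own single-token group,
# which indicates its bytes do not form a complete UTF-8 character even in the
# full decoding; the test checks that split_tokens_on_unicode still splits such
# a sequence token by token instead of merging the invalid piece with a neighbor.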