tokenizer.py

import os
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import GPT2TokenizerFast

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "iw": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}
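
# Illustrative lookups, derived directly from the mappings above:
#   TO_LANGUAGE_CODE["japanese"]  -> "ja"   (inverted from LANGUAGES)
#   TO_LANGUAGE_CODE["castilian"] -> "es"   (explicit alias entry)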


@dataclass(frozen=True)
class Tokenizer:
    """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""

    tokenizer: "GPT2TokenizerFast"
    language: Optional[str]
    sot_sequence: Tuple[int, ...]

    def encode(self, text, **kwargs):
        return self.tokenizer.encode(text, **kwargs)

    def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
        return self.tokenizer.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, tokens) -> str:
        """
        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
        This method decodes the given tokens, annotating timestamp tokens, e.g. "<|1.08|>".
        """
        outputs = [[]]
        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)
        outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
        return "".join(outputs)

    @property
    @lru_cache()
    def eot(self) -> int:
        return self.tokenizer.eos_token_id

    @property
    @lru_cache()
    def sot(self) -> int:
        return self._get_single_token_id("<|startoftranscript|>")

    @property
    @lru_cache()
    def sot_lm(self) -> int:
        return self._get_single_token_id("<|startoflm|>")

    @property
    @lru_cache()
    def sot_prev(self) -> int:
        return self._get_single_token_id("<|startofprev|>")

    @property
    @lru_cache()
    def no_speech(self) -> int:
        return self._get_single_token_id("<|nospeech|>")

    @property
    @lru_cache()
    def no_timestamps(self) -> int:
        return self._get_single_token_id("<|notimestamps|>")

    @property
    @lru_cache()
    def timestamp_begin(self) -> int:
        return self.tokenizer.all_special_ids[-1] + 1

    @property
    @lru_cache()
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have a language token configured")

        additional_tokens = dict(
            zip(
                self.tokenizer.additional_special_tokens,
                self.tokenizer.additional_special_tokens_ids,
            )
        )
        candidate = f"<|{self.language}|>"
        if candidate in additional_tokens:
            return additional_tokens[candidate]

        raise KeyError(f"Language {self.language} not found in tokenizer.")
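
    # Illustrative: with language == "en", this property returns the id assigned to the
    # "<|en|>" special token added in `build_tokenizer` below.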

    @property
    @lru_cache()
    def all_language_tokens(self) -> Tuple[int, ...]:
        result = []
        for token, token_id in zip(
            self.tokenizer.additional_special_tokens,
            self.tokenizer.additional_special_tokens_ids,
        ):
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)

    @property
    @lru_cache()
    def all_language_codes(self) -> Tuple[str, ...]:
        return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)

    @property
    @lru_cache()
    def sot_sequence_including_notimestamps(self) -> Tuple[int, ...]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @property
    @lru_cache()
    def non_speech_tokens(self) -> Tuple[int, ...]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))
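
    # Typical (hypothetical) caller sketch: a decoding loop might apply
    #   logits[:, tokenizer.non_speech_tokens] = float("-inf")
    # before sampling, so these annotation-only tokens are never generated.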

    def _get_single_token_id(self, text) -> int:
        tokens = self.tokenizer.encode(text)
        assert len(tokens) == 1, f"{text} is not encoded as a single token"
        return tokens[0]


@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    path = os.path.join(os.path.dirname(__file__), "assets", name)
    tokenizer = GPT2TokenizerFast.from_pretrained(path)

    specials = [
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
    ]

    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
    return tokenizer
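
# Note (assumption about ordering): after `add_special_tokens`, `all_special_ids` is expected
# to start with <|endoftext|> followed by the additional specials in the order listed above.
# The index arithmetic in `get_tokenizer` below relies on this: index 1 is
# <|startoftranscript|>, index -6 is <|translate|>, index -5 is <|transcribe|>, and the last
# id is immediately followed by the timestamp range used by `timestamp_begin`.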


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    language: Optional[str] = None,
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        tokenizer_name = "multilingual"
        task = task or "transcribe"
        language = language or "en"
    else:
        tokenizer_name = "gpt2"
        task = None
        language = None

    tokenizer = build_tokenizer(name=tokenizer_name)
    all_special_ids: List[int] = tokenizer.all_special_ids
    sot: int = all_special_ids[1]
    translate: int = all_special_ids[-6]
    transcribe: int = all_special_ids[-5]

    langs = tuple(LANGUAGES.keys())
    sot_sequence = [sot]
    if language is not None:
        sot_sequence.append(sot + 1 + langs.index(language))
    if task is not None:
        sot_sequence.append(transcribe if task == "transcribe" else translate)

    return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
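
# Minimal usage sketch (illustrative; actual token ids depend on the tokenizer assets on disk):
#
#   tokenizer = get_tokenizer(multilingual=True, language="japanese", task="translate")
#   tokenizer.language        # "ja"  (alias resolved via TO_LANGUAGE_CODE)
#   tokenizer.sot_sequence    # ids for (<|startoftranscript|>, <|ja|>, <|translate|>)
#   tokenizer.decode_with_timestamps(token_list)  # text annotated with "<|1.08|>"-style markers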