tokenizer.py

import os
import string
from dataclasses import dataclass
from functools import cached_property, lru_cache
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import GPT2TokenizerFast

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}
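# Both full language names and the aliases above resolve to the ISO codes from LANGUAGES,
# e.g. TO_LANGUAGE_CODE["spanish"] and TO_LANGUAGE_CODE["castilian"] both yield "es".
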
@dataclass(frozen=True)
class Tokenizer:
    """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""

    tokenizer: "GPT2TokenizerFast"
    language: Optional[str]
    sot_sequence: Tuple[int]

    def encode(self, text, **kwargs):
        return self.tokenizer.encode(text, **kwargs)

    def decode(
        self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
    ):
        return self.tokenizer.decode(token_ids, **kwargs)
    def decode_with_timestamps(self, tokens) -> str:
        """
        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
        This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
        """
        outputs = [[]]
        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)
        return "".join(
            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
        )
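    # Illustrative example: with the 0.02 s-per-step scale used above, a token equal to
    # `timestamp_begin + 54` is rendered as "<|1.08|>"; text tokens in between are decoded
    # normally, so a result might look like "<|0.00|> Hello there<|1.08|>" (hypothetical text).
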
    @cached_property
    def eot(self) -> int:
        return self.tokenizer.eos_token_id

    @cached_property
    def transcribe(self) -> int:
        return self._get_single_token_id("<|transcribe|>")

    @cached_property
    def translate(self) -> int:
        return self._get_single_token_id("<|translate|>")

    @cached_property
    def sot(self) -> int:
        return self._get_single_token_id("<|startoftranscript|>")

    @cached_property
    def sot_lm(self) -> int:
        return self._get_single_token_id("<|startoflm|>")

    @cached_property
    def sot_prev(self) -> int:
        return self._get_single_token_id("<|startofprev|>")

    @cached_property
    def no_speech(self) -> int:
        return self._get_single_token_id("<|nospeech|>")

    @cached_property
    def no_timestamps(self) -> int:
        return self._get_single_token_id("<|notimestamps|>")

    @cached_property
    def timestamp_begin(self) -> int:
        return self.tokenizer.all_special_ids[-1] + 1
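    # Note: timestamp token ids start immediately after the last added special token
    # ("<|notimestamps|>"), and each successive id advances the timestamp by 0.02 s,
    # matching the scale used in `decode_with_timestamps`.
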
    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        additional_tokens = dict(
            zip(
                self.tokenizer.additional_special_tokens,
                self.tokenizer.additional_special_tokens_ids,
            )
        )
        candidate = f"<|{self.language}|>"
        if candidate in additional_tokens:
            return additional_tokens[candidate]

        raise KeyError(f"Language {self.language} not found in tokenizer.")
    @cached_property
    def all_language_tokens(self) -> Tuple[int]:
        result = []
        for token, token_id in zip(
            self.tokenizer.additional_special_tokens,
            self.tokenizer.additional_special_tokens_ids,
        ):
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)

    @cached_property
    def all_language_codes(self) -> Tuple[str]:
        return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])
    @cached_property
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.tokenizer.encode(symbol),
                self.tokenizer.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))
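    # Hedged usage note: callers typically feed this tuple into the decoder's
    # token-suppression options so these symbols are never sampled; the exact option
    # name depends on the decoding code that uses this tokenizer, not on this file.
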
    def _get_single_token_id(self, text) -> int:
        tokens = self.tokenizer.encode(text)
        assert len(tokens) == 1, f"{text} is not encoded as a single token"
        return tokens[0]

    def split_to_word_tokens(self, tokens: List[int]):
        if self.language in {"zh", "ja", "th", "lo", "my"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)
    def split_tokens_on_unicode(self, tokens: List[int]):
        words = []
        word_tokens = []
        current_tokens = []

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)
            if "\ufffd" not in decoded:
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []

        return words, word_tokens
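    # The "\ufffd" check above relies on the byte-level BPE emitting the Unicode replacement
    # character while a multi-byte character is still incomplete; tokens are accumulated until
    # the buffer decodes cleanly, and each clean decode marks a split point.
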
    def split_tokens_on_spaces(self, tokens: List[int]):
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens
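    # Illustrative example (hypothetical subwords): [" Hel", "lo", ",", " world"] would merge
    # into the words [" Hello", ",", " world"], since a subword starts a new word only when it
    # is a special token, begins with a space, or is bare punctuation.
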
@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    path = os.path.join(os.path.dirname(__file__), "assets", name)
    tokenizer = GPT2TokenizerFast.from_pretrained(path)

    specials = [
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
    ]

    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
    return tokenizer
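# Note: `get_tokenizer` below indexes `all_special_ids` positionally (sot at index 1,
# "<|translate|>" at -6, "<|transcribe|>" at -5), so the order of `specials` above is
# load-bearing and should not be reshuffled.
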
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    language: Optional[str] = None,
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        tokenizer_name = "multilingual"
        task = task or "transcribe"
        language = language or "en"
    else:
        tokenizer_name = "gpt2"
        task = None
        language = None

    tokenizer = build_tokenizer(name=tokenizer_name)
    all_special_ids: List[int] = tokenizer.all_special_ids
    sot: int = all_special_ids[1]
    translate: int = all_special_ids[-6]
    transcribe: int = all_special_ids[-5]

    langs = tuple(LANGUAGES.keys())
    sot_sequence = [sot]
    if language is not None:
        sot_sequence.append(sot + 1 + langs.index(language))
    if task is not None:
        sot_sequence.append(transcribe if task == "transcribe" else translate)

    return Tokenizer(
        tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
    )
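
if __name__ == "__main__":
    # Minimal usage sketch, assuming the "multilingual" tokenizer assets are present under
    # assets/ next to this file (as `build_tokenizer` expects). Not part of the library API,
    # just an illustration of how the pieces above fit together.
    tok = get_tokenizer(multilingual=True, language="spanish", task="transcribe")
    print(tok.language)      # "es" -- names and aliases resolve via TO_LANGUAGE_CODE
    print(tok.sot_sequence)  # ids for (<|startoftranscript|>, <|es|>, <|transcribe|>)
    print(tok.decode_with_timestamps(list(tok.sot_sequence) + [tok.timestamp_begin]))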