import base64
import os
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple

import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}
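
# For example, both the canonical name and the alias resolve to the same code:
#   TO_LANGUAGE_CODE["chinese"]  -> "zh"
#   TO_LANGUAGE_CODE["mandarin"] -> "zh"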


@dataclass
class Tokenizer:
    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

    encoding: tiktoken.Encoding
    num_languages: int
    language: Optional[str] = None
    task: Optional[str] = None
    sot_sequence: Tuple[int] = ()
    special_tokens: Dict[str, int] = field(default_factory=dict)

    def __post_init__(self):
        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot: int = self.special_tokens["<|startoftranscript|>"]
        translate: int = self.special_tokens["<|translate|>"]
        transcribe: int = self.special_tokens["<|transcribe|>"]

        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]
        if self.language is not None:
            sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)

        self.sot_sequence = tuple(sot_sequence)
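        # e.g. with the encoding built by get_encoding() below, language="en" and
        # task="transcribe" yield the ids for <|startoftranscript|><|en|><|transcribe|>.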

    def encode(self, text, **kwargs):
        return self.encoding.encode(text, **kwargs)

    def decode(self, token_ids: List[int], **kwargs) -> str:
        token_ids = [t for t in token_ids if t < self.timestamp_begin]
        return self.encoding.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
        """
        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
        """
        return self.encoding.decode(token_ids, **kwargs)
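
    # Example: decode_with_timestamps() renders the id self.timestamp_begin + 54 as
    # "<|1.08|>" (54 * 0.02 s), whereas plain decode() filters out all timestamp tokens.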

    @cached_property
    def eot(self) -> int:
        return self.encoding.eot_token

    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self) -> int:
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self) -> int:
        return self.special_tokens["<|startoftranscript|>"]

    @cached_property
    def sot_lm(self) -> int:
        return self.special_tokens["<|startoflm|>"]

    @cached_property
    def sot_prev(self) -> int:
        return self.special_tokens["<|startofprev|>"]

    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def no_timestamps(self) -> int:
        return self.special_tokens["<|notimestamps|>"]

    @cached_property
    def timestamp_begin(self) -> int:
        return self.special_tokens["<|0.00|>"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        return self.to_language_token(self.language)

    def to_language_token(self, language):
        if token := self.special_tokens.get(f"<|{language}|>", None):
            return token

        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self) -> Tuple[int]:
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self) -> Tuple[str]:
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @cached_property
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encoding.encode(symbol),
                self.encoding.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens: List[int]):
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            # emit the accumulated tokens as a "word" once they decode to valid unicode,
            # i.e. there is no stray replacement character, or the replacement character
            # genuinely occurs at this position in the full decoding
            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens


@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in open(vocab_path) if line)
    }
    n_vocab = len(ranks)
    special_tokens = {}

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        # 1501 timestamp tokens, <|0.00|> through <|30.00|> in 0.02 s increments
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )
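

# A minimal usage sketch, not part of the module's public API: it assumes the
# multilingual vocabulary file is present at assets/multilingual.tiktoken next to
# this file, which is where get_encoding() above expects to find it.
if __name__ == "__main__":
    tokenizer = get_tokenizer(True, language="en", task="transcribe")
    ids = tokenizer.encode("Hello world")
    print("token ids:", ids)
    print("decoded:", tokenizer.decode(ids))
    print("sot sequence:", tokenizer.sot_sequence_including_notimestamps)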