import os
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import GPT2TokenizerFast

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "iw": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}
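
# Illustrative sketch (not part of the original tables): a user-supplied language
# name or alias normalizes to a two-letter code, and the code maps back to the
# canonical name, e.g.
#
#   >>> TO_LANGUAGE_CODE["castilian"]
#   'es'
#   >>> LANGUAGES[TO_LANGUAGE_CODE["castilian"]]
#   'spanish'
#
# `get_tokenizer()` below applies this lookup to its `language` argument.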


@dataclass(frozen=True)
class Tokenizer:
    """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""

    tokenizer: "GPT2TokenizerFast"
    language: Optional[str]
    sot_sequence: Tuple[int, ...]

    def encode(self, text, **kwargs):
        return self.tokenizer.encode(text, **kwargs)

    def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
        return self.tokenizer.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, tokens) -> str:
        """
        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
        This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
        """
        outputs = [[]]
        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)
        outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
        return "".join(outputs)
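
    # Worked example for `decode_with_timestamps` (illustrative, values assumed):
    # each timestamp token encodes a multiple of 0.02 s above `timestamp_begin`,
    # so a token with id `timestamp_begin + 54` renders as 54 * 0.02 = 1.08 s,
    # i.e. the "<|1.08|>" marker from the docstring, while the text tokens
    # between two timestamp markers are decoded as usual.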

    @property
    @lru_cache()
    def eot(self) -> int:
        return self.tokenizer.eos_token_id

    @property
    @lru_cache()
    def sot(self) -> int:
        return self._get_single_token_id("<|startoftranscript|>")

    @property
    @lru_cache()
    def sot_lm(self) -> int:
        return self._get_single_token_id("<|startoflm|>")

    @property
    @lru_cache()
    def sot_prev(self) -> int:
        return self._get_single_token_id("<|startofprev|>")

    @property
    @lru_cache()
    def no_speech(self) -> int:
        return self._get_single_token_id("<|nospeech|>")

    @property
    @lru_cache()
    def no_timestamps(self) -> int:
        return self._get_single_token_id("<|notimestamps|>")

    @property
    @lru_cache()
    def timestamp_begin(self) -> int:
        return self.tokenizer.all_special_ids[-1] + 1

    @property
    @lru_cache()
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have a language token configured")

        additional_tokens = dict(
            zip(
                self.tokenizer.additional_special_tokens,
                self.tokenizer.additional_special_tokens_ids,
            )
        )
        candidate = f"<|{self.language}|>"
        if candidate in additional_tokens:
            return additional_tokens[candidate]

        raise KeyError(f"Language {self.language} not found in tokenizer.")

    @property
    @lru_cache()
    def all_language_tokens(self) -> Tuple[int, ...]:
        result = []
        for token, token_id in zip(
            self.tokenizer.additional_special_tokens,
            self.tokenizer.additional_special_tokens_ids,
        ):
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)

    @property
    @lru_cache()
    def all_language_codes(self) -> Tuple[str, ...]:
        return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)

    @property
    @lru_cache()
    def sot_sequence_including_notimestamps(self) -> Tuple[int, ...]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @property
    @lru_cache()
    def non_speech_tokens(self) -> Tuple[int, ...]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        while keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because
        # these are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))
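
    # Usage note (an assumption about the calling code, which lives outside this file):
    # the decoding logic is expected to feed `tokenizer.non_speech_tokens` into its
    # token-suppression step so these symbols never start a sampled word, e.g.
    #
    #   suppress = tokenizer.non_speech_tokens  # tuple of token ids to ban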

    def _get_single_token_id(self, text) -> int:
        tokens = self.tokenizer.encode(text)
        assert len(tokens) == 1, f"{text} is not encoded as a single token"
        return tokens[0]


@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    path = os.path.join(os.path.dirname(__file__), "assets", name)
    tokenizer = GPT2TokenizerFast.from_pretrained(path)

    specials = [
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
    ]

    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
    return tokenizer
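
# Note on assets (an assumption about the repository layout, not enforced above):
# `build_tokenizer("multilingual")` loads a Hugging Face tokenizer saved under
# `<this directory>/assets/multilingual/` (vocab.json, merges.txt, etc.), e.g. one
# produced by `GPT2TokenizerFast.save_pretrained(path)`.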


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    language: Optional[str] = None,
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        tokenizer_name = "multilingual"
        task = task or "transcribe"
        language = language or "en"
    else:
        tokenizer_name = "gpt2"
        task = None
        language = None

    tokenizer = build_tokenizer(name=tokenizer_name)
    all_special_ids: List[int] = tokenizer.all_special_ids
    sot: int = all_special_ids[1]
    translate: int = all_special_ids[-6]
    transcribe: int = all_special_ids[-5]

    langs = tuple(LANGUAGES.keys())
    sot_sequence = [sot]
    if language is not None:
        sot_sequence.append(sot + 1 + langs.index(language))
    if task is not None:
        sot_sequence.append(transcribe if task == "transcribe" else translate)

    return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
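

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library): requires the
    # tokenizer assets to be present under `assets/` next to this file.
    tokenizer = get_tokenizer(multilingual=True, task="transcribe", language="es")
    print(tokenizer.sot_sequence)  # ids for <|startoftranscript|>, <|es|>, <|transcribe|>
    ids = tokenizer.encode(" Hola mundo")
    print(tokenizer.decode(ids))  # " Hola mundo"
    # A timestamp token 54 steps above `timestamp_begin` renders as 54 * 0.02 = 1.08 s:
    print(tokenizer.decode_with_timestamps([tokenizer.timestamp_begin] + ids + [tokenizer.timestamp_begin + 54]))
    # expected (assumed) output: "<|0.00|> Hola mundo<|1.08|>"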