import hashlib
import io
import os
import urllib
import urllib.request
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__
# Download URLs for the official model checkpoints, keyed by model name.
# The penultimate path segment of each URL is the file's SHA256 hex digest;
# `_download` parses it back out to verify the downloaded bytes.
_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    # "large" is an alias: same checkpoint URL as "large-v2".
    "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
# Keys mirror those of `_MODELS`; the values are passed to `model.set_alignment_heads(...)`
# in `load_model` when the model was loaded by name.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    # "large" is an alias for "large-v2", matching `_MODELS`.
    "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
}
  42. def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
  43. os.makedirs(root, exist_ok=True)
  44. expected_sha256 = url.split("/")[-2]
  45. download_target = os.path.join(root, os.path.basename(url))
  46. if os.path.exists(download_target) and not os.path.isfile(download_target):
  47. raise RuntimeError(f"{download_target} exists and is not a regular file")
  48. if os.path.isfile(download_target):
  49. with open(download_target, "rb") as f:
  50. model_bytes = f.read()
  51. if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
  52. return model_bytes if in_memory else download_target
  53. else:
  54. warnings.warn(
  55. f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
  56. )
  57. with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
  58. with tqdm(
  59. total=int(source.info().get("Content-Length")),
  60. ncols=80,
  61. unit="iB",
  62. unit_scale=True,
  63. unit_divisor=1024,
  64. ) as loop:
  65. while True:
  66. buffer = source.read(8192)
  67. if not buffer:
  68. break
  69. output.write(buffer)
  70. loop.update(len(buffer))
  71. model_bytes = open(download_target, "rb").read()
  72. if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
  73. raise RuntimeError(
  74. "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
  75. )
  76. return model_bytes if in_memory else download_target
  77. def available_models() -> List[str]:
  78. """Returns the names of available models"""
  79. return list(_MODELS.keys())
  80. def load_model(
  81. name: str,
  82. device: Optional[Union[str, torch.device]] = None,
  83. download_root: str = None,
  84. in_memory: bool = False,
  85. ) -> Whisper:
  86. """
  87. Load a Whisper ASR model
  88. Parameters
  89. ----------
  90. name : str
  91. one of the official model names listed by `whisper.available_models()`, or
  92. path to a model checkpoint containing the model dimensions and the model state_dict.
  93. device : Union[str, torch.device]
  94. the PyTorch device to put the model into
  95. download_root: str
  96. path to download the model files; by default, it uses "~/.cache/whisper"
  97. in_memory: bool
  98. whether to preload the model weights into host memory
  99. Returns
  100. -------
  101. model : Whisper
  102. The Whisper ASR model instance
  103. """
  104. if device is None:
  105. device = "cuda" if torch.cuda.is_available() else "cpu"
  106. if download_root is None:
  107. default = os.path.join(os.path.expanduser("~"), ".cache")
  108. download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
  109. if name in _MODELS:
  110. checkpoint_file = _download(_MODELS[name], download_root, in_memory)
  111. alignment_heads = _ALIGNMENT_HEADS[name]
  112. elif os.path.isfile(name):
  113. checkpoint_file = open(name, "rb").read() if in_memory else name
  114. alignment_heads = None
  115. else:
  116. raise RuntimeError(
  117. f"Model {name} not found; available models = {available_models()}"
  118. )
  119. with (
  120. io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
  121. ) as fp:
  122. checkpoint = torch.load(fp, map_location=device)
  123. del checkpoint_file
  124. dims = ModelDimensions(**checkpoint["dims"])
  125. model = Whisper(dims)
  126. model.load_state_dict(checkpoint["model_state_dict"])
  127. if alignment_heads is not None:
  128. model.set_alignment_heads(alignment_heads)
  129. return model.to(device)