import hashlib
import io
import os
import urllib.request
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}
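
# Note: the second-to-last path segment of each URL above is the SHA-256
# checksum of the checkpoint; _download() below parses it out of the URL
# and uses it to verify both cached and freshly downloaded files.
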
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}
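

# For illustration, the sketch below shows one way such a dump could be decoded
# back into a boolean (n_layers, n_heads) mask. This is a hedged sketch, not the
# canonical implementation (which lives in Whisper.set_alignment_heads); in
# particular, the gzip-over-base85 framing is an assumption about the dump
# format, and this helper is not used anywhere in the package.
def _decode_alignment_heads_sketch(dump: bytes, n_layers: int, n_heads: int):
    import base64
    import gzip

    import numpy as np

    # base85-decode, then decompress, yielding one byte per boolean flag
    raw = gzip.decompress(base64.b85decode(dump))
    return np.frombuffer(raw, dtype=bool).reshape(n_layers, n_heads)

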
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    # the checksum is embedded as the second-to-last path segment of the URL
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    # reuse a cached copy if its checksum matches; otherwise re-download
    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    with open(download_target, "rb") as f:
        model_bytes = f.read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root : str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory : bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        # a local checkpoint path: read it into memory only if requested
        if in_memory:
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
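

# Example usage (a sketch, kept as a comment so that importing the package
# stays side-effect free). "audio.wav" is a placeholder path, not a file that
# ships with the repository:
#
#     import whisper
#
#     print(whisper.available_models())
#     model = whisper.load_model("base")   # downloads and caches on first use
#     result = model.transcribe("audio.wav")
#     print(result["text"])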