__init__.py

import hashlib
import io
import os
import urllib.request
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .transcribe import transcribe

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
}
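
# Each model URL above embeds the checkpoint's SHA256 digest as the second-to-last
# path component; _download() below parses it out (url.split("/")[-2]) and uses it
# to verify both cached and freshly downloaded files.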


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        model_bytes = open(download_target, "rb").read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    model_bytes = open(download_target, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target
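
# Example (illustrative sketch): given a valid cache directory, _download() returns
# the local checkpoint path, or the raw bytes when in_memory=True, re-downloading
# only if the cached file's checksum does not match.
#
#   path = _download(_MODELS["tiny"], os.path.expanduser("~/.cache/whisper"), in_memory=False)
#   data = _download(_MODELS["tiny"], os.path.expanduser("~/.cache/whisper"), in_memory=True)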


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> Whisper:
  52. """
  53. Load a Whisper ASR model
  54. Parameters
  55. ----------
  56. name : str
  57. one of the official model names listed by `whisper.available_models()`, or
  58. path to a model checkpoint containing the model dimensions and the model state_dict.
  59. device : Union[str, torch.device]
  60. the PyTorch device to put the model into
  61. download_root: str
  62. path to download the model files; by default, it uses "~/.cache/whisper"
  63. in_memory: bool
  64. whether to preload the model weights into host memory
  65. Returns
  66. -------
  67. model : Whisper
  68. The Whisper ASR model instance
  69. """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    return model.to(device)
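

# Example usage of the public API exposed by this module (illustrative sketch;
# "audio.wav" is a placeholder path):
#
#   import whisper
#
#   model = whisper.load_model("base", device="cpu")
#   result = whisper.transcribe(model, "audio.wav")
#   print(result["text"])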