
Avoid keeping redundant copies of model weights in memory during load (#42)

* don't keep copies of model weights in host memory

* adding type annotation

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
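
For context, a minimal caller-side sketch of the behavior this change adds. `load_model` and the `in_memory` flag come from the diff below; the model name "base" and the device choice are illustrative assumptions, not part of this commit:

    import whisper

    # Default: the checkpoint is read straight from the cached file on disk,
    # so no redundant bytes copy of the weights is kept in host memory.
    model = whisper.load_model("base")

    # Opt-in: keep the raw checkpoint bytes in host memory instead, e.g. when
    # the cache directory lives on a slow or ephemeral filesystem.
    model = whisper.load_model("base", device="cpu", in_memory=True)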
Niklas K, 3 years ago
parent
commit
f296bcd3fa
1 changed file with 19 additions and 15 deletions

whisper/__init__.py (+19 -15)

@@ -27,12 +27,11 @@ _MODELS = {
 }
 
 
-def _download(url: str, root: str) -> bytes:
+def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)
-    filename = os.path.basename(url)
 
     expected_sha256 = url.split("/")[-2]
-    download_target = os.path.join(root, filename)
+    download_target = os.path.join(root, os.path.basename(url))
 
     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")
@@ -40,7 +39,7 @@ def _download(url: str, root: str) -> bytes:
     if os.path.isfile(download_target):
         model_bytes = open(download_target, "rb").read()
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes
+            return model_bytes if in_memory else download_target
         else:
             warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
 
@@ -58,7 +57,7 @@ def _download(url: str, root: str) -> bytes:
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
 
-    return model_bytes
+    return model_bytes if in_memory else download_target
 
 
 def available_models() -> List[str]:
@@ -66,7 +65,7 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())
 
 
-def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None) -> Whisper:
+def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
     """
     Load a Whisper ASR model
 
@@ -79,28 +78,33 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
         the PyTorch device to put the model into
     download_root: str
         path to download the model files; by default, it uses "~/.cache/whisper"
+    in_memory: bool
+        whether to preload the model weights into host memory
 
     Returns
     -------
     model : Whisper
         The Whisper ASR model instance
     """
+
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    if download_root is None:
+        download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+
     if name in _MODELS:
-        model_bytes = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/whisper"))
+        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
     elif os.path.isfile(name):
-        model_bytes = open(name, "rb").read()
+        checkpoint_file = open(name, "rb").read() if in_memory else name
     else:
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
 
-    with io.BytesIO(model_bytes) as fp:
-        checkpoint = torch.load(fp, map_location="cpu")
+    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
+        checkpoint = torch.load(fp, map_location=device)
+    del checkpoint_file
 
     dims = ModelDimensions(**checkpoint["dims"])
-    state_dict = checkpoint["model_state_dict"]
     model = Whisper(dims)
-    model.load_state_dict(state_dict)
-
-    if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.load_state_dict(checkpoint["model_state_dict"])
 
     return model.to(device)
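
The core of the change, isolated as a sketch: loading from an open file lets torch.load consume the checkpoint without first materializing it as a bytes object, and map_location=device puts the tensors directly on the target device. The helper name load_checkpoint below is hypothetical, used only to illustrate the branching the diff introduces:

    import io
    import torch

    def load_checkpoint(source, device, in_memory=False):
        # `source` is raw checkpoint bytes when in_memory=True, otherwise a
        # file path; either way torch.load reads from a file-like object.
        with (io.BytesIO(source) if in_memory else open(source, "rb")) as fp:
            return torch.load(fp, map_location=device)

Deleting checkpoint_file immediately after torch.load, as the new code does, drops the last reference to the raw bytes in the in-memory case, so the weights live only once, inside the loaded checkpoint.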