Browse Source

Avoid keeping redundant copies of model weights in memory during load (#42)

* don't keep copies of model weights in host memory

* adding type annotation

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
Niklas K 2 years ago
parent
commit
f296bcd3fa
1 changed files with 19 additions and 15 deletions
  1. 19 15
      whisper/__init__.py

+ 19 - 15
whisper/__init__.py

@@ -27,12 +27,11 @@ _MODELS = {
 }
 
 
-def _download(url: str, root: str) -> bytes:
+def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)
-    filename = os.path.basename(url)
 
     expected_sha256 = url.split("/")[-2]
-    download_target = os.path.join(root, filename)
+    download_target = os.path.join(root, os.path.basename(url))
 
     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")
@@ -40,7 +39,7 @@ def _download(url: str, root: str) -> bytes:
     if os.path.isfile(download_target):
         model_bytes = open(download_target, "rb").read()
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes
+            return model_bytes if in_memory else download_target
         else:
             warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
 
@@ -58,7 +57,7 @@ def _download(url: str, root: str) -> bytes:
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
 
-    return model_bytes
+    return model_bytes if in_memory else download_target
 
 
 def available_models() -> List[str]:
@@ -66,7 +65,7 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())
 
 
-def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None) -> Whisper:
+def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
     """
     Load a Whisper ASR model
 
@@ -79,28 +78,33 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
         the PyTorch device to put the model into
     download_root: str
         path to download the model files; by default, it uses "~/.cache/whisper"
+    in_memory: bool
+        whether to preload the model weights into host memory
 
     Returns
     -------
     model : Whisper
         The Whisper ASR model instance
     """
+
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    if download_root is None:
+        download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+
     if name in _MODELS:
-        model_bytes = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/whisper"))
+        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
     elif os.path.isfile(name):
-        model_bytes = open(name, "rb").read()
+        checkpoint_file = open(name, "rb").read() if in_memory else name
     else:
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
 
-    with io.BytesIO(model_bytes) as fp:
-        checkpoint = torch.load(fp, map_location="cpu")
+    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
+        checkpoint = torch.load(fp, map_location=device)
+    del checkpoint_file
 
     dims = ModelDimensions(**checkpoint["dims"])
-    state_dict = checkpoint["model_state_dict"]
     model = Whisper(dims)
-    model.load_state_dict(state_dict)
-
-    if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.load_state_dict(checkpoint["model_state_dict"])
 
     return model.to(device)