Avoid keeping redundant copies of model weights in memory during load (#42)

* don't keep copies of model weights in host memory

* adding type annotation

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
Niklas K, 2 years ago
commit f296bcd3fa
1 changed file with 19 additions and 15 deletions:

whisper/__init__.py (+19 −15)
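The new in_memory flag is opt-in, so existing callers are unaffected. A minimal usage sketch (assuming the package is imported as whisper and "base" is one of the registered model names in _MODELS):

    import whisper

    # default: weights are deserialized straight from the cached checkpoint
    # file, so no second copy of the raw bytes lingers in host memory
    model = whisper.load_model("base")

    # opt-in: read the checkpoint into host memory first, e.g. when the
    # download cache lives on slow or ephemeral storage
    model = whisper.load_model("base", in_memory=True)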

@@ -27,12 +27,11 @@ _MODELS = {
 }


-def _download(url: str, root: str) -> bytes:
+def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)
-    filename = os.path.basename(url)

     expected_sha256 = url.split("/")[-2]
-    download_target = os.path.join(root, filename)
+    download_target = os.path.join(root, os.path.basename(url))

     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")
@@ -40,7 +39,7 @@ def _download(url: str, root: str) -> bytes:
     if os.path.isfile(download_target):
         model_bytes = open(download_target, "rb").read()
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes
+            return model_bytes if in_memory else download_target
         else:
             warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

@@ -58,7 +57,7 @@ def _download(url: str, root: str) -> bytes:
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")

-    return model_bytes
+    return model_bytes if in_memory else download_target


 def available_models() -> List[str]:
@@ -66,7 +65,7 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())


-def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None) -> Whisper:
+def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
     """
     Load a Whisper ASR model

@@ -79,28 +78,33 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
         the PyTorch device to put the model into
     download_root: str
         path to download the model files; by default, it uses "~/.cache/whisper"
+    in_memory: bool
+        whether to preload the model weights into host memory

     Returns
     -------
     model : Whisper
         The Whisper ASR model instance
     """
+
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    if download_root is None:
+        download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+
     if name in _MODELS:
-        model_bytes = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/whisper"))
+        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
     elif os.path.isfile(name):
-        model_bytes = open(name, "rb").read()
+        checkpoint_file = open(name, "rb").read() if in_memory else name
     else:
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

-    with io.BytesIO(model_bytes) as fp:
-        checkpoint = torch.load(fp, map_location="cpu")
+    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
+        checkpoint = torch.load(fp, map_location=device)
+    del checkpoint_file

     dims = ModelDimensions(**checkpoint["dims"])
-    state_dict = checkpoint["model_state_dict"]
     model = Whisper(dims)
-    model.load_state_dict(state_dict)
-
-    if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.load_state_dict(checkpoint["model_state_dict"])

     return model.to(device)
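The core of the change, distilled into a before/after sketch (checkpoint_path and device are hypothetical stand-ins for illustration, not names from the diff):

    import io
    import torch

    checkpoint_path = "model.pt"  # hypothetical local checkpoint
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # before: the raw checkpoint bytes and the deserialized tensors coexist
    # in host memory until model_bytes goes out of scope
    model_bytes = open(checkpoint_path, "rb").read()
    with io.BytesIO(model_bytes) as fp:
        checkpoint = torch.load(fp, map_location="cpu")

    # after: torch.load reads from the open file handle, and map_location=device
    # materializes the tensors on the target device instead of staging them on CPU
    with open(checkpoint_path, "rb") as fp:
        checkpoint = torch.load(fp, map_location=device)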