Przeglądaj źródła

Fix bug (#305)

Fix bug: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)
Michael Monashev 2 lat temu
rodzic
commit
f680570016
1 zmienionych plików z 1 dodań i 1 usunięć
  1. 1 1
      whisper/audio.py

+ 1 - 1
whisper/audio.py

@@ -55,7 +55,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     """
     if torch.is_tensor(array):
         if array.shape[axis] > length:
-            array = array.index_select(dim=axis, index=torch.arange(length))
+            array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
 
         if array.shape[axis] < length:
             pad_widths = [(0, 0)] * array.ndim