Explorar o código

Fix bug (#305)

Fix bug: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)
Michael Monashev %!s(int64=2) %!d(string=hai) anos
pai
achega
f680570016
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      whisper/audio.py

+ 1 - 1
whisper/audio.py

@@ -55,7 +55,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     """
     if torch.is_tensor(array):
         if array.shape[axis] > length:
-            array = array.index_select(dim=axis, index=torch.arange(length))
+            array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
 
         if array.shape[axis] < length:
             pad_widths = [(0, 0)] * array.ndim