|
@@ -55,7 +55,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
|
|
"""
|
|
|
if torch.is_tensor(array):
|
|
|
if array.shape[axis] > length:
|
|
|
- array = array.index_select(dim=axis, index=torch.arange(length))
|
|
|
+ array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
|
|
|
|
|
|
if array.shape[axis] < length:
|
|
|
pad_widths = [(0, 0)] * array.ndim
|