Pārlūkot izejas kodu

Support batch-dimension in log_mel_spectogram (#839)

Markus Hennerbichler 2 gadi atpakaļ
vecāks
revīzija
6df3ea1fb5
1 mainītis faili ar 1 papildinājumiem un 1 dzēšanām
  1. 1 1
      whisper/audio.py

+ 1 - 1
whisper/audio.py

@@ -113,7 +113,7 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
 
     window = torch.hann_window(N_FFT).to(audio.device)
     stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
-    magnitudes = stft[:, :-1].abs() ** 2
+    magnitudes = stft[..., :-1].abs() ** 2
 
     filters = mel_filters(audio.device, n_mels)
     mel_spec = filters @ magnitudes