Explorar o código

Fixed CoW RuntimeError in DecodingTask.run() (#240)

Corentin Jemine %!s(int64=2) %!d(string=hai) anos
pai
achega
9e653bd0ea
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      whisper/decoding.py

+ 1 - 1
whisper/decoding.py

@@ -615,7 +615,7 @@ class DecodingTask:
         n_audio: int = mel.shape[0]
 
         audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
-        tokens: Tensor = torch.tensor([self.initial_tokens]).expand(n_audio, -1)
+        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
 
         # detect language if requested, overwriting the language token
         languages, language_probs = self._detect_language(audio_features, tokens)