|
@@ -214,10 +214,10 @@ class Whisper(nn.Module):
|
|
|
)
|
|
|
|
|
|
def embed_audio(self, mel: torch.Tensor):
|
|
|
- return self.encoder.forward(mel)
|
|
|
+ return self.encoder(mel)
|
|
|
|
|
|
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
|
|
- return self.decoder.forward(tokens, audio_features)
|
|
|
+ return self.decoder(tokens, audio_features)
|
|
|
|
|
|
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
|
|
return self.decoder(tokens, self.encoder(mel))
|