فهرست منبع

Use PyTorch as logits transpose for ONNX support (#141)

Michael Goin 2 سال پیش
والد
کامیت
9c8183a179
1فایلهای تغییر یافته به همراه1 افزوده شده و 1 حذف شده
  1. 1 1
      whisper/model.py

+ 1 - 1
whisper/model.py

@@ -189,7 +189,7 @@ class TextDecoder(nn.Module):
             x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = (x @ self.token_embedding.weight.to(x.dtype).T).float()
+        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
 
         return logits