Browse Source

MultiHeadAttention to return qk as well

Jong Wook Kim 2 năm trước cách đây
mục cha
commit
53807677fe
1 tập tin đã thay đổi với 7 bổ sung6 xóa
  1. 7 6
      whisper/model.py

+ 7 - 6
whisper/model.py

@@ -82,8 +82,8 @@ class MultiHeadAttention(nn.Module):
             k = kv_cache[self.key]
             v = kv_cache[self.value]
 
-        wv = self.qkv_attention(q, k, v, mask)
-        return self.out(wv)
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), qk
 
     def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
         n_batch, n_ctx, n_state = q.shape
@@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.float()
 
-        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
 
 
 class ResidualAttentionBlock(nn.Module):
@@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module):
         mask: Optional[Tensor] = None,
         kv_cache: Optional[dict] = None,
     ):
-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
         x = x + self.mlp(self.mlp_ln(x))
         return x