@@ -82,8 +82,8 @@ class MultiHeadAttention(nn.Module):
             k = kv_cache[self.key]
             v = kv_cache[self.value]

-        wv = self.qkv_attention(q, k, v, mask)
-        return self.out(wv)
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), qk

     def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
         n_batch, n_ctx, n_state = q.shape
@@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.float()

-        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()


 class ResidualAttentionBlock(nn.Module):
@@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module):
         mask: Optional[Tensor] = None,
         kv_cache: Optional[dict] = None,
     ):
-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
         x = x + self.mlp(self.mlp_ln(x))
         return x
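
With this change, every attention call returns a tuple `(output, qk)`, where `qk` is the detached, float32, pre-softmax attention matrix. Below is a minimal sketch (not part of the diff) of one way a caller could collect these tensors from the decoder's cross-attention layers using standard PyTorch forward hooks, e.g. for inspection or visualization. `whisper.load_model` is the library's real loader; the zero `mel` spectrogram and dummy token id are placeholders for illustration only, not a prescribed usage.

```python
import torch
import whisper

# Assumes a model built from the patched MultiHeadAttention above.
model = whisper.load_model("base")

qk_per_layer = []  # one (n_batch, n_head, n_text_ctx, n_audio_ctx) tensor per decoder block

hooks = [
    block.cross_attn.register_forward_hook(
        # forward() now returns (self.out(wv), qk); keep only the qk tensor
        lambda _module, _inputs, outputs: qk_per_layer.append(outputs[-1])
    )
    for block in model.decoder.blocks
]

with torch.no_grad():
    # Placeholder inputs: a zero log-mel spectrogram and a single dummy token id.
    # In practice these come from whisper's audio preprocessing and tokenizer.
    mel = torch.zeros(1, model.dims.n_mels, 3000, device=model.device)
    tokens = torch.zeros(1, 1, dtype=torch.long, device=model.device)
    _ = model.decoder(tokens, model.encoder(mel))

for hook in hooks:
    hook.remove()
```

Because `qkv_attention` now casts with `.float()` before the softmax and returns `qk.detach()`, the collected matrices are float32 and carry no autograd history, so they can be stored or post-processed without keeping the surrounding computation graph alive.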