| 
					
				 | 
			
			
				@@ -82,8 +82,8 @@ class MultiHeadAttention(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             k = kv_cache[self.key] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             v = kv_cache[self.value] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        wv = self.qkv_attention(q, k, v, mask) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return self.out(wv) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        wv, qk = self.qkv_attention(q, k, v, mask) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return self.out(wv), qk 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         n_batch, n_ctx, n_state = q.shape 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         qk = q @ k 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if mask is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             qk = qk + mask[:n_ctx, :n_ctx] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        qk = qk.float() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        w = F.softmax(qk.float(), dim=-1).to(q.dtype) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        w = F.softmax(qk, dim=-1).to(q.dtype) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class ResidualAttentionBlock(nn.Module): 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         mask: Optional[Tensor] = None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         kv_cache: Optional[dict] = None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if self.cross_attn: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         x = x + self.mlp(self.mlp_ln(x)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return x 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |