@@ -82,8 +82,8 @@ class MultiHeadAttention(nn.Module):
             k = kv_cache[self.key]
             v = kv_cache[self.value]
 
-        wv = self.qkv_attention(q, k, v, mask)
-        return self.out(wv)
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), qk
 
     def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
         n_batch, n_ctx, n_state = q.shape
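With this hunk, `MultiHeadAttention.forward` returns a `(output, qk)` pair instead of a single tensor. A minimal sketch of the new call contract, assuming this patch is applied to OpenAI's `whisper.model` (the dimensions are illustrative, not prescribed by the patch):

```python
import torch
from whisper.model import MultiHeadAttention  # assumes openai-whisper with this patch

# Illustrative dimensions only.
mha = MultiHeadAttention(n_state=64, n_head=4)
x = torch.randn(1, 10, 64)  # (batch, ctx, state)

out, qk = mha(x)  # forward now returns a pair
print(out.shape)  # torch.Size([1, 10, 64]): projected attention output, as before
print(qk.shape)   # torch.Size([1, 4, 10, 10]): per-head pre-softmax attention logits
```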
@@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.float()
 
-        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
 
 
 class ResidualAttentionBlock(nn.Module):
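Two details here: the `float()` cast now happens before the softmax, so the returned logits are float32 regardless of the activation dtype (the softmax itself was already computed in float32 for numerical stability), and `qk.detach()` keeps those logits out of the autograd graph. A standalone illustration of the dtype pattern, not tied to this patch:

```python
import torch
import torch.nn.functional as F

qk = torch.randn(2, 4, 8, 8, dtype=torch.float16, requires_grad=True)

qk32 = qk.float()                         # upcast once, before the softmax
w = F.softmax(qk32, dim=-1).to(qk.dtype)  # attention weights, cast back to fp16
logits = qk32.detach()                    # float32, no grad history: safe to store or plot

print(w.dtype, logits.dtype, logits.requires_grad)
# torch.float16 torch.float32 False
```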
@@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module):
         mask: Optional[Tensor] = None,
         kv_cache: Optional[dict] = None,
     ):
-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
         x = x + self.mlp(self.mlp_ln(x))
         return x
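`ResidualAttentionBlock.forward` still returns only the residual stream; the `[0]` indexing drops the attention logits at this level rather than threading them through every caller. A caller that wants the per-layer cross-attention logits can register a forward hook on each block's `cross_attn` module, whose output is now the `(output, qk)` pair. A sketch, assuming `openai-whisper` with this patch applied; the model size and audio path are placeholders:

```python
import whisper

model = whisper.load_model("base")  # placeholder model size
qks = [None] * len(model.decoder.blocks)

# Each hook sees the module's return value, now (output, qk); keep the qk
# from the most recent forward call per layer. `i=i` binds the index early.
hooks = [
    block.cross_attn.register_forward_hook(
        lambda _module, _inputs, outputs, i=i: qks.__setitem__(i, outputs[1])
    )
    for i, block in enumerate(model.decoder.blocks)
]

result = model.transcribe("audio.wav")  # placeholder audio file
for hook in hooks:
    hook.remove()

# qks[i] now holds (batch, head, text_ctx, audio_ctx) logits for decoder
# layer i, from the last decoding step that ran.
```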