```diff
@@ -62,6 +62,7 @@ class MultiHeadAttention(nn.Module):
         self.key = Linear(n_state, n_state, bias=False)
         self.value = Linear(n_state, n_state)
         self.out = Linear(n_state, n_state)
+        self.last_qk = None
 
     def forward(
         self,
@@ -96,6 +97,8 @@ class MultiHeadAttention(nn.Module):
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
 
+        self.last_qk = qk.detach()
+
         w = F.softmax(qk.float(), dim=-1).to(q.dtype)
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
 
```
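Taken together, the two hunks cache the raw pre-softmax attention scores on the module every time it runs, so callers can read them back after a forward pass (e.g., for attention visualization or cross-attention alignment). Below is a minimal, self-contained sketch of that pattern; `TinyAttention` and its dimensions are illustrative stand-ins for the patched `MultiHeadAttention`, not the actual Whisper implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyAttention(nn.Module):
    """Illustrative stand-in mirroring the patched attention module."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.last_qk = None  # populated on every forward pass, as in the diff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_batch, n_ctx, n_state = x.shape
        scale = (n_state // self.n_head) ** -0.25
        # reshape to (batch, head, ctx, head_dim) before computing scores
        shape = (n_batch, n_ctx, self.n_head, -1)
        q = self.query(x).view(*shape).permute(0, 2, 1, 3) * scale
        k = self.key(x).view(*shape).permute(0, 2, 3, 1) * scale
        v = self.value(x).view(*shape).permute(0, 2, 1, 3)
        qk = q @ k  # raw scores, shape (batch, head, ctx, ctx)
        # cache pre-softmax scores for later inspection; detach() keeps the
        # cached copy out of the autograd graph so it holds no backward state
        self.last_qk = qk.detach()
        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
        return self.out((w @ v).permute(0, 2, 1, 3).flatten(start_dim=2))


attn = TinyAttention(n_state=64, n_head=4)
_ = attn(torch.randn(1, 10, 64))
print(attn.last_qk.shape)  # torch.Size([1, 4, 10, 10]): one score matrix per head
```

Storing `qk.detach()` rather than `qk` is the key design choice: the cached tensor shares storage with the scores but carries no graph, so reading `last_qk` after inference costs nothing extra and never retains gradients across calls.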