@@ -62,7 +62,6 @@ class MultiHeadAttention(nn.Module):
         self.key = Linear(n_state, n_state, bias=False)
         self.value = Linear(n_state, n_state)
         self.out = Linear(n_state, n_state)
-        self.last_qk = None

     def forward(
         self,
@@ -97,8 +96,6 @@ class MultiHeadAttention(nn.Module):
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]

-        self.last_qk = qk.detach()
-
         w = F.softmax(qk.float(), dim=-1).to(q.dtype)
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

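The two removed assignments cached the pre-softmax attention scores (`qk`) on the module between calls. For reference, below is a minimal, self-contained sketch of the same masked scaled-dot-product computation; it is not the repository's code, and returning the detached scores to the caller is shown only as one possible stateless alternative to the removed `self.last_qk` caching.

```python
import torch
import torch.nn.functional as F
from typing import Optional
from torch import Tensor, nn


class TinyMultiHeadAttention(nn.Module):
    """Illustrative stand-in; not the repository's MultiHeadAttention."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None):
        q, k, v = self.query(x), self.key(x), self.value(x)
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25

        # Split into heads: (batch, head, ctx, head_dim), scaling q and k.
        q = q.view(n_batch, n_ctx, self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(n_batch, n_ctx, self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(n_batch, n_ctx, self.n_head, -1).permute(0, 2, 1, 3)

        # Pre-softmax attention scores, as in the hunk above.
        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
        wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

        # Instead of caching qk on the module (the removed self.last_qk),
        # return the detached scores alongside the output.
        return self.out(wv), qk.detach()


if __name__ == "__main__":
    attn = TinyMultiHeadAttention(n_state=64, n_head=4)
    x = torch.randn(2, 10, 64)
    causal_mask = torch.full((10, 10), float("-inf")).triu_(1)
    y, scores = attn(x, causal_mask)
    print(y.shape, scores.shape)  # torch.Size([2, 10, 64]) torch.Size([2, 4, 10, 10])
```

In general, returning the scores rather than stashing them on the module avoids keeping stale activations alive between forward passes; callers that do not need the scores can simply ignore the second return value.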