@@ -72,15 +72,15 @@ class MultiHeadAttention(nn.Module):
     ):
         q = self.query(x)
 
-        if kv_cache is None or xa is None:
+        if kv_cache is None or xa is None or self.key not in kv_cache:
             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
             # otherwise, perform key/value projections for self- or cross-attention as usual.
             k = self.key(x if xa is None else xa)
             v = self.value(x if xa is None else xa)
         else:
             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-            k = kv_cache.get(self.key, self.key(xa))
-            v = kv_cache.get(self.value, self.value(xa))
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
 
         wv = self.qkv_attention(q, k, v, mask)
         return self.out(wv)
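
Note on the change: the extra `self.key not in kv_cache` guard means the `else` branch only runs when the cross-attention projections are already cached, so direct indexing is safe. It also avoids the cost hidden in the old `kv_cache.get(self.key, self.key(xa))` form, where Python evaluates the `self.key(xa)` default eagerly and therefore recomputed the projection on every call even on a cache hit. The sketch below is a minimal, self-contained illustration (not the patched module's actual code) of how a forward-hook-based cache keyed by the projection modules interacts with this branch; the names `TinyCrossAttention` and `install_cache_hooks` are hypothetical and only stand in for the surrounding model code.

# Minimal sketch, assuming a hook-based cache keyed by the projection modules.
# TinyCrossAttention and install_cache_hooks are illustrative names, not part
# of the original codebase; they only demonstrate the branch in the diff above.
import torch
import torch.nn as nn


class TinyCrossAttention(nn.Module):
    """Stand-in for the key/value caching logic shown in the diff."""

    def __init__(self, d_model: int = 8):
        super().__init__()
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model)

    def forward(self, x, xa=None, kv_cache=None):
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # First call (or plain self-attention): run the projections.
            # If hooks are installed, they record the outputs in kv_cache.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # Cross-attention with a warm cache: reuse the stored tensors.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        return k, v


def install_cache_hooks(block, kv_cache):
    """Hypothetical helper: remember each projection's output on first use."""

    def save(module, _inputs, output):
        kv_cache.setdefault(module, output)
        return kv_cache[module]

    return [block.key.register_forward_hook(save),
            block.value.register_forward_hook(save)]


if __name__ == "__main__":
    torch.manual_seed(0)
    block, cache = TinyCrossAttention(), {}
    hooks = install_cache_hooks(block, cache)
    x, xa = torch.randn(1, 3, 8), torch.randn(1, 5, 8)
    k1, _ = block(x, xa=xa, kv_cache=cache)  # computes and caches k/v
    k2, _ = block(x, xa=xa, kv_cache=cache)  # takes the kv_cache[...] branch
    assert torch.equal(k1, k2)
    for h in hooks:
        h.remove()

With the hooks installed, the first cross-attention call populates the cache and every later call with the same cache dict falls through to the `kv_cache[self.key]` / `kv_cache[self.value]` path, which is exactly the situation the new condition distinguishes.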