| 
														
															@@ -72,15 +72,15 @@ class MultiHeadAttention(nn.Module): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     ): 
														 | 
														
														 | 
														
															     ): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         q = self.query(x) 
														 | 
														
														 | 
														
															         q = self.query(x) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        if kv_cache is None or xa is None: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if kv_cache is None or xa is None or self.key not in kv_cache: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; 
														 | 
														
														 | 
														
															             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             # otherwise, perform key/value projections for self- or cross-attention as usual. 
														 | 
														
														 | 
														
															             # otherwise, perform key/value projections for self- or cross-attention as usual. 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             k = self.key(x if xa is None else xa) 
														 | 
														
														 | 
														
															             k = self.key(x if xa is None else xa) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             v = self.value(x if xa is None else xa) 
														 | 
														
														 | 
														
															             v = self.value(x if xa is None else xa) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         else: 
														 | 
														
														 | 
														
															         else: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
														 | 
														
														 | 
														
															             # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            k = kv_cache.get(self.key, self.key(xa)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            v = kv_cache.get(self.value, self.value(xa)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            k = kv_cache[self.key] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            v = kv_cache[self.value] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															         wv = self.qkv_attention(q, k, v, mask) 
														 | 
														
														 | 
														
															         wv = self.qkv_attention(q, k, v, mask) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         return self.out(wv) 
														 | 
														
														 | 
														
															         return self.out(wv) 
														 |