|
@@ -72,15 +72,15 @@ class MultiHeadAttention(nn.Module):
|
|
|
):
|
|
|
q = self.query(x)
|
|
|
|
|
|
- if kv_cache is None or xa is None:
|
|
|
+ if kv_cache is None or xa is None or self.key not in kv_cache:
|
|
|
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
|
|
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
|
|
k = self.key(x if xa is None else xa)
|
|
|
v = self.value(x if xa is None else xa)
|
|
|
else:
|
|
|
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
|
|
- k = kv_cache.get(self.key, self.key(xa))
|
|
|
- v = kv_cache.get(self.value, self.value(xa))
|
|
|
+ k = kv_cache[self.key]
|
|
|
+ v = kv_cache[self.value]
|
|
|
|
|
|
wv = self.qkv_attention(q, k, v, mask)
|
|
|
return self.out(wv)
|