
Fix attention caching to make it actually work (#370)

Vicki Anand 2 years ago
parent
commit
9f70a352f9
1 file changed with 3 additions and 3 deletions

whisper/model.py (+3, -3)

@@ -72,15 +72,15 @@ class MultiHeadAttention(nn.Module):
     ):
         q = self.query(x)
 
-        if kv_cache is None or xa is None:
+        if kv_cache is None or xa is None or self.key not in kv_cache:
             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
             # otherwise, perform key/value projections for self- or cross-attention as usual.
             k = self.key(x if xa is None else xa)
             v = self.value(x if xa is None else xa)
         else:
             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-            k = kv_cache.get(self.key, self.key(xa))
-            v = kv_cache.get(self.value, self.value(xa))
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
 
         wv = self.qkv_attention(q, k, v, mask)
         return self.out(wv)
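
Why the change works: Python's dict.get(key, default) evaluates the default argument before doing the lookup, so the old kv_cache.get(self.key, self.key(xa)) recomputed the cross-attention key/value projections on every decoding step even when cached tensors were available, so the cache never saved any work. The new condition treats a missing entry as a cache miss and routes it through the normal projection path, where forward hooks can store the result; later calls then index the cache directly. Below is a minimal, self-contained sketch of that caching pattern, assuming a hypothetical install_cache_hooks helper and a stripped-down single-head attention module; it is illustrative only and not the actual hook implementation in whisper/model.py.

import torch
import torch.nn as nn


class TinyCrossAttention(nn.Module):
    """Stripped-down stand-in for MultiHeadAttention (single head, no mask)."""

    def __init__(self, d_model: int):
        super().__init__()
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model)

    def forward(self, x, xa=None, kv_cache=None):
        q = self.query(x)
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # cache miss (or self-attention): run the projections; hooks,
            # if installed, record the outputs keyed by the module object.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # cache hit: reuse the tensors computed on the first call.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        w = (q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5).softmax(dim=-1)
        return w @ v


def install_cache_hooks(attn: nn.Module):
    """Hypothetical helper: store each projection's output in a dict keyed
    by the projection module itself, keeping the first (cross-attention)
    result. The real model's hooks also handle growing self-attention keys."""
    cache = {}

    def save(module, _inputs, output):
        cache.setdefault(module, output)

    hooks = [attn.key.register_forward_hook(save),
             attn.value.register_forward_hook(save)]
    return cache, hooks


attn = TinyCrossAttention(8)
cache, hooks = install_cache_hooks(attn)
x, xa = torch.randn(1, 3, 8), torch.randn(1, 5, 8)
attn(x, xa=xa, kv_cache=cache)   # first step: projections run, hooks fill the cache
attn(x, xa=xa, kv_cache=cache)   # later steps: cached k/v reused, no recompute
for h in hooks:
    h.remove()

With the old kv_cache.get(self.key, self.key(xa)) pattern, the second call above would still run both projections and then discard the results, since the fallback argument is evaluated before the lookup.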