Stronger KNN attention
In standard self-attention:
$$ \text{Attention}(Q, K) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) $$
where $Q$ and $K$ are the query and key matrices and $d_k$ is the dimension of the keys; the value matrix $V$ is applied separately at the end.
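As a concrete reference, here is a minimal numpy sketch of these weights (the function names and shapes are my own; $Q$ is $(n, d_k)$, $K$ is $(m, d_k)$):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_weights(Q, K):
    # softmax(Q K^T / sqrt(d_k)): one row of weights per query.
    d_k = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k))
```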
In KNN attention (as in Memorizing Transformers, Wu et al., 2022), $\text{KNN}(Q, K)$ denotes attention weights computed over keys retrieved by an approximate k-nearest-neighbor lookup into a cache of past key-value pairs.
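A brute-force sketch of that lookup, reusing `softmax` from above (a real implementation would use an approximate index such as FAISS or ScaNN; this assumes `k` is smaller than the number of keys):

```python
def knn_attention_weights(Q, K, k=32):
    # Stand-in for the approximate kNN lookup: each query attends
    # only over its k highest-scoring keys; all others get weight 0.
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                        # (n, m)
    topk = np.argpartition(-scores, k - 1, axis=-1)[:, :k]
    masked = np.full_like(scores, -np.inf)
    np.put_along_axis(masked, topk,
                      np.take_along_axis(scores, topk, axis=-1), axis=-1)
    return softmax(masked)                                 # -inf -> weight 0
```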
Let's combine this with the standard attention weights, scaling its contribution with a coefficient $\beta$:
$$ \text{CombinedAttention}(Q, K) = \text{Attention}(Q, K) + \beta \times \text{KNN}(Q, K) $$
Final output:
$$ \text{Output} = \text{CombinedAttention}(Q, K) \times V $$
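Putting the pieces together, here is a sketch that follows the formulas above literally (both terms share the same $K$ and $V$ here; in Memorizing Transformers the kNN term instead attends over a separate cache of past key-value pairs):

```python
def combined_output(Q, K, V, beta=0.5, k=32):
    # CombinedAttention = Attention + beta * KNN, then apply V.
    W = attention_weights(Q, K) + beta * knn_attention_weights(Q, K, k=k)
    return W @ V

# Toy usage with assumed shapes: 8 tokens, 64-dim heads.
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(8, 64)) for _ in range(3))
out = combined_output(Q, K, V, beta=0.3, k=4)  # shape (8, 64)
```

Note that with this additive combination each query's weights sum to $1 + \beta$ rather than 1, so $\beta$ directly controls how much extra mass the retrieved keys contribute.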