
Stronger KNN attention

In standard self-attention:

$$ \text{Attention}(Q, K) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) $$

where $Q$ and $K$ are the query and key matrices and $d_k$ is the dimension of the keys. Note that this expression is only the attention-weight matrix; the value matrix $V$ is applied in the final output below.
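To pin down the shapes, here is a minimal NumPy sketch of these attention weights. The function name, variable names, and dimensions are illustrative, not from the original:

```python
import numpy as np

def attention_weights(Q, K):
    """Standard attention weights: softmax(Q K^T / sqrt(d_k))."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)               # (n_queries, n_keys)
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)
```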

KNN attention, as in Memorizing Transformers (Wu et al., 2022), adds a term $\text{KNN}(Q, K)$: attention weights computed over keys retrieved from a cache of past key-value pairs via an approximate k-nearest-neighbor lookup.
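As a rough sketch of what that lookup could compute (continuing the NumPy snippet above): each query scores only its $k$ nearest cached keys and softmaxes over those, with every other position getting zero weight. The brute-force search below stands in for the approximate-nearest-neighbor index used in practice; `K_cache` and `k` are illustrative names:

```python
def knn_attention_weights(Q, K_cache, k=4):
    """Attention weights restricted to each query's k nearest cached keys.

    Brute-force top-k search is a stand-in for the approximate index
    used in practice; positions outside the top-k get weight 0.
    """
    d_k = K_cache.shape[-1]
    scores = Q @ K_cache.T / np.sqrt(d_k)          # (n_queries, n_cached_keys)
    topk = np.argsort(scores, axis=-1)[:, -k:]     # indices of the k best keys per query
    mask = np.full_like(scores, -np.inf)
    np.put_along_axis(mask, topk, 0.0, axis=-1)    # keep top-k scores, -inf elsewhere
    masked = scores + mask
    masked -= masked.max(axis=-1, keepdims=True)   # row max is finite (a kept score)
    weights = np.exp(masked)                       # exp(-inf) -> 0 for masked positions
    return weights / weights.sum(axis=-1, keepdims=True)
```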

Let’s augment the standard attention weights with this KNN term, scaling its relative importance with a coefficient $\beta$:

$$ \text{CombinedAttention}(Q, K) = \text{Attention}(Q, K) + \beta \times \text{KNN}(Q, K) $$

Final output:

$$ \text{Output} = \text{CombinedAttention}(Q, K) \times V $$
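Putting the pieces together (continuing the sketch above), and assuming for simplicity that the standard and KNN terms attend over the same key/value set so the two weight matrices can be added elementwise, exactly as the formulas are written:

```python
def combined_attention(Q, K, V, beta=0.5, k=4):
    """Output = (Attention(Q, K) + beta * KNN(Q, K)) @ V."""
    weights = attention_weights(Q, K) + beta * knn_attention_weights(Q, K, k=k)
    return weights @ V

# Toy usage: 3 queries over 8 key/value pairs of dimension 16.
rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 16))
K = rng.standard_normal((8, 16))
V = rng.standard_normal((8, 16))
print(combined_attention(Q, K, V).shape)  # (3, 16)
```

One consequence of adding the terms as written: each row of the combined weights sums to $1 + \beta$ rather than 1, so $\beta$ also scales the output magnitude. Renormalizing the rows, or gating the two terms as Memorizing Transformers does, would keep the weights a proper distribution.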