Attention Mechanism
Attention block: takes a sequence of vectors \(x_1, \dots, x_L \in \mathbb{R}^d\) and outputs another sequence \(x_1', \dots, x_L' \in \mathbb{R}^d\).
Within the block there are weight matrices \(W_Q, W_K, W_V \in \mathbb{R}^{d\times d}\), all of which are learned. Stacking the inputs as rows, \(X = \begin{bmatrix}x_1\\x_2\\\vdots\\x_L\end{bmatrix} \in \mathbb{R}^{L\times d}\), \[ \mathbb{R}^{L\times d}\begin{cases} Q = XW_Q\\ K = XW_K\\ V = XW_V\end{cases} \qquad Q = \begin{bmatrix}q_1\\q_2\\\vdots\\q_L\end{bmatrix} \] and similarly \(K\) has rows \(k_1, \dots, k_L\) and \(V\) has rows \(v_1, \dots, v_L\).
The queries \(q_1, \dots, q_L\) (indexing rows) and keys \(k_1, \dots, k_L\) (indexing columns) define the \(L\times L\) attention matrix \(A\), \[ A_{ij} = K(q_i, k_j) \] for a kernel (similarity) function \(K\).
...
Non-normalized attention: \(AV\)
The rows of \(AV\) are the output embeddings.
\[\begin{bmatrix} x_1'\\\vdots\\x_L'\end{bmatrix} = AV \in \mathbb{R}^{L\times d}, \qquad A: L\times L,\; V: L\times d \] \[ x_i' = \sum_{j=1}^{L}{K(q_i, k_j)\,v_j} \]
Normalized attention:
\[ x_i' = \sum_{j=1}^{L}{\frac{K(q_i, k_j)\,v_j}{\sum_{s=1}^{L}{K(q_i, k_s)}}} \] The weights sum to 1 and are \(\geq 0\); \(K: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}^+\).
In matrix form:
\[ D^{-1}AV \]
where \(D\) is a diagonal \(L\times L\) matrix whose entries are the partition functions
\[ d_i = \sum_{s=1}^{L}{K(q_i, k_s)} \]
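A minimal NumPy sketch of the block above. The shapes, random weights, and the exponential (softmax-style) kernel \(K(q,k) = e^{qk^T/\sqrt{d}}\) used later in these notes are illustrative choices:

```python
import numpy as np

rng = np.random.default_rng(0)
L, d = 6, 4                                   # sequence length, embedding dimension

X = rng.normal(size=(L, d))                   # inputs x_1, ..., x_L stacked as rows
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V           # queries, keys, values (all L x d)

# Attention matrix: A_ij = K(q_i, k_j), here with the exponential kernel.
A = np.exp(Q @ K.T / np.sqrt(d))              # L x L, strictly positive

# Non-normalized attention: x_i' = sum_j K(q_i, k_j) v_j
out_unnorm = A @ V                            # L x d

# Normalized attention: divide each row by its partition function d_i.
D_inv = 1.0 / A.sum(axis=1, keepdims=True)    # d_i = sum_s K(q_i, k_s)
out_norm = D_inv * (A @ V)                    # equals D^{-1} A V; each row's weights sum to 1
```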
Unidirectional Attention
Non-normalized setting:
\[ A_{masked}V \]
where \(A_{masked}\) is the masked attention matrix: zero out the upper-triangular part (strictly above the diagonal),
\[ (A_{masked})_{ij} = \begin{cases} A_{ij} & j \le i\\ 0 & j > i \end{cases} \]
Normalized:
\[D_{masked}^{-1}A_{masked}V\]
where \(D_{masked}\) is diagonal, \(L\times L\), with entries
\[ d_i = \sum_{s=1}^{i}{K(q_i, k_s)} \]
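A sketch of the unidirectional variant under the same illustrative setup as above; the helper name is hypothetical:

```python
import numpy as np

def causal_attention(Q, K, V, normalize=True):
    """Unidirectional attention: token i attends only to positions j <= i."""
    L, d = Q.shape
    A = np.exp(Q @ K.T / np.sqrt(d))            # full L x L attention matrix
    A_masked = np.tril(A)                       # zero out the upper-triangular part
    if not normalize:
        return A_masked @ V                     # A_masked V
    d_i = A_masked.sum(axis=1, keepdims=True)   # d_i = sum_{s<=i} K(q_i, k_s)
    return (A_masked / d_i) @ V                 # D_masked^{-1} A_masked V
```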
Problems with the standard attention algorithm
- Time and space complexity for computing attention are quadratic in \(L\), so it cannot be used for very long sequences.
Sparsification:
- attend just to a few tokens (either fixed or learned)
- in the unidirectional case (the lower-triangular part is non-zero): attend to the last \(l\) tokens, i.e. a band of width \(l\) just below the diagonal (see the mask sketch below);
- in the bidirectional case: attend to the closest \(l\) tokens, i.e. a diagonal strip.
For graph data, people often attend only to neighbors (graph attention methods).
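The fixed local patterns above can be written as boolean masks on the \(L\times L\) attention matrix; a small sketch (the window size and helper name are illustrative):

```python
import numpy as np

def local_masks(L, window):
    """Boolean L x L masks for fixed local attention patterns."""
    i = np.arange(L)[:, None]                   # query positions
    j = np.arange(L)[None, :]                   # key positions
    # Unidirectional: attend to the last `window` tokens (a band below the diagonal).
    uni = (j <= i) & (i - j < window)
    # Bidirectional: attend to the closest tokens on either side (a diagonal strip).
    bi = np.abs(i - j) < window
    return uni, bi

# Only the masked entries of A are ever computed, so the cost drops
# from O(L^2 d) to O(L * window * d).
```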
- attend to a few tokens, but learn which ones to attend to.
How to choose which keys \(k_j\) to attend to? Close ones, e.g. the 10 closest.
If \(Q=K\) (i.e. \(W_Q = W_K\)), the queries can be clustered into groups
- clustering
- hashing / nearest-neighbour approach: encode (hash) the query so that its closest keys (e.g. the 10 nearest neighbours) can be found cheaply, reducing complexity (see the sketch below).
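One possible reading of the hashing idea, as a hedged sketch rather than a specific published method: hash queries and keys with random hyperplanes and let each query attend only within its bucket (when \(W_Q = W_K\), a query and its own key always land in the same bucket):

```python
import numpy as np

def bucketed_attention(Q, K, V, n_planes=4, seed=0):
    """Sparse attention via random-hyperplane hashing: attend only within a bucket."""
    rng = np.random.default_rng(seed)
    L, d = Q.shape
    planes = rng.normal(size=(d, n_planes))
    bits = 2 ** np.arange(n_planes)
    # Hash code: pattern of signs of the projections (2**n_planes possible buckets).
    q_code = (Q @ planes > 0).astype(int) @ bits
    k_code = (K @ planes > 0).astype(int) @ bits
    out = np.zeros_like(V)
    for b in np.unique(q_code):
        qi = np.where(q_code == b)[0]
        kj = np.where(k_code == b)[0]
        if len(kj) == 0:                        # no keys in this bucket: leave zeros
            continue
        A = np.exp(Q[qi] @ K[kj].T / np.sqrt(d))
        out[qi] = (A / A.sum(axis=1, keepdims=True)) @ V[kj]
    return out
```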
Efficient Dense Attention
Can we approximate the attention while the matrix stays dense? Idea: decomposition.
Bidirectional non-normalized attention:
\(AV\)
Let's try to rewrite \(A\) as \(A\approx F_1\cdot F_2 \cdots F_L\) for some simpler matrices \(F_i\). Then
\[ (F_1\cdot F_2\cdots F_L)\cdot V = F_1\cdot (F_2\cdot(\cdots(F_L\cdot V)\cdots)), \qquad AV \in \mathbb{R}^{L \times d}, \]
so the product can be evaluated right-to-left without ever forming \(A\).
Q: Can we rewrite \(A = F_1 \cdot F_2\) with \(F_1: L\times m\), \(F_2: m\times L\), where \(m < L\)? In general no: \(A\) can be full rank (\(\det A \neq 0\)), e.g. when its entries are large on the diagonal and small everywhere else, while \(\operatorname{rank}(F_1 F_2) \leq m < L\), so there is no exact solution.
What about random \(F_i\)?
- Can we find random matrices \(F_1\) and \(F_2\) such that:
- \(A = E[F_1F_2]\)
- the error of the approximation is small
\[ A_{ij} = e^{\frac{q_ik_j^T}{\sqrt{d}}} = e^{\bar{q_i}\bar{k_j}^T}, \qquad \bar{q_i} = \frac{q_i}{d^{1/4}}, \; \bar{k_j} = \frac{k_j}{d^{1/4}} \]
Using a random feature map \(\phi\) for this exponential kernel (constructed below), we conclude that
\[ A_{ij} = E[\phi(\bar{q_i})\phi(\bar{k_j})^T] \]
\[ Q' = \begin{bmatrix} \phi(\bar{q_1})\\\vdots\\\phi(\bar{q_L}) \end{bmatrix}, \qquad K' = \begin{bmatrix} \phi(\bar{k_1})\\\vdots\\\phi(\bar{k_L}) \end{bmatrix} \]
\[ (Q'(K')^T)_{ij} = \phi(\bar{q_i})\phi(\bar{k_j})^T \] \[ E[(Q'(K')^T)_{ij}] = E[\phi(\bar{q_i})\phi(\bar{k_j})^T] \] \[ A = E[Q'(K')^T] \]
\[ \phi_{SM}(x) = e^{\frac{||x||^2}{2}}\,\phi_{Gauss}(x) = e^{\frac{||x||^2}{2}}\, \frac{1}{\sqrt{m}} \begin{bmatrix}\cos(w_1^Tx)\\\vdots\\\cos(w_m^Tx)\\\sin(w_1^Tx)\\\vdots\\\sin(w_m^Tx) \end{bmatrix}, \qquad w_i \overset{iid}{\sim} \mathcal{N}(0, I_d) \]
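A sketch of the trigonometric feature map above, with a small Monte Carlo check that \(E[\phi_{SM}(x)\phi_{SM}(y)^T] = e^{xy^T}\) (the \(1/\sqrt{d}\) rescaling is assumed to be already absorbed into \(\bar{q_i}, \bar{k_j}\)); the dimensions and sample sizes are illustrative:

```python
import numpy as np

def phi_sm_trig(x, W):
    """Trigonometric features for the softmax kernel: phi_SM(x) = e^{||x||^2/2} phi_Gauss(x)."""
    m = W.shape[0]                                   # number of random frequencies w_1, ..., w_m
    proj = W @ x                                     # (m,) vector of w_i^T x
    feats = np.concatenate([np.cos(proj), np.sin(proj)]) / np.sqrt(m)
    return np.exp(np.dot(x, x) / 2.0) * feats        # feature vector of dimension 2m

rng = np.random.default_rng(0)
d, m = 4, 20000
x, y = 0.3 * rng.normal(size=d), 0.3 * rng.normal(size=d)
W = rng.normal(size=(m, d))                          # w_i ~ N(0, I_d), i.i.d.

approx = np.dot(phi_sm_trig(x, W), phi_sm_trig(y, W))
exact = np.exp(np.dot(x, y))
print(approx, exact)                                 # the two numbers should be close
```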
\[ AV = E[Q'(K')^T]V = E[Q'(K')^TV] \]
\[ AV \approx Q'K'^TV = Q'(K'^TV)\]
\[ Q'\in \mathbb{R}^{L\times 2m}, \quad (K')^T \in \mathbb{R}^{2m\times L}, \quad V \in \mathbb{R}^{L\times d} \]
(the trigonometric feature map has dimension \(2m\); below we simply write \(m\) for the feature dimension). Space complexity: \(O(Ld + md + Lm)\).
Time complexity: \(O(mLd)\), versus \(O(L^2d)\) for standard attention; this is a gain whenever \(m \ll L\).
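A sketch of the associativity trick: evaluating \(Q'((K')^TV)\) right-to-left never materializes an \(L\times L\) matrix. The function name is hypothetical, and the commented usage refers to the `phi_sm_trig` sketch above with renormalized queries/keys as assumed inputs:

```python
import numpy as np

def linear_unnormalized_attention(Qp, Kp, V):
    """Approximate AV as Q'((K')^T V): O(mLd) time, no L x L matrix is formed."""
    # Qp: L x M feature queries, Kp: L x M feature keys, V: L x d values.
    KtV = Kp.T @ V            # M x d   -- cost O(L M d)
    return Qp @ KtV           # L x d   -- cost O(L M d)

# Usage with the trigonometric features from the sketch above:
#   Qp = np.stack([phi_sm_trig(q, W) for q in Q_bar])   # Q_bar: rows q_i / d**0.25
#   Kp = np.stack([phi_sm_trig(k, W) for k in K_bar])   # K_bar: rows k_j / d**0.25
#   out = linear_unnormalized_attention(Qp, Kp, V)      # approximates A V
```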
Problems with trigonometric features
For a given row, the partition function \[ \sum_{s=1}^{L}{K(q_i, k_s)} \] can be close to 0 when many entries in that row are small; the trigonometric estimators of those entries can take negative values, which makes the renormalization unstable.
In general, as long as the kernel used for attention can be written as
\[ K(x,y) = E[\phi(x)\phi(y)^T] \]
for some \(\phi: \mathbb{R}^{d} \to \mathbb{R}^{m}\) (\(m < \infty\)), deterministic or random, we get an attention computation mechanism with \(O(mLd)\) time complexity and \(O(mL+Ld+md)\) space complexity.
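A generic sketch of this statement: given any feature map \(\phi\) with \(K(x,y)=E[\phi(x)\phi(y)^T]\), both \(AV\) and the partition functions can be estimated without forming \(A\). The helper below is an illustration, with \(\phi\) passed in as a parameter:

```python
import numpy as np

def kernel_feature_attention(Q_bar, K_bar, V, phi):
    """Normalized attention D^{-1} A V estimated via a feature map phi: R^d -> R^m."""
    Qp = np.stack([phi(q) for q in Q_bar])       # L x m
    Kp = np.stack([phi(k) for k in K_bar])       # L x m
    KtV = Kp.T @ V                               # m x d
    num = Qp @ KtV                               # estimates A V                      (L x d)
    den = Qp @ Kp.sum(axis=0)                    # estimates d_i = sum_s K(q_i, k_s)  (L,)
    # Note: with trigonometric features, den can be near zero or even negative
    # (the issue noted above), which destabilizes the division.
    return num / den[:, None]                    # estimates D^{-1} A V
```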
Remark:
if random features are used, they should be periodically redrawn in downstream algorithms that use attention-based models.
Remark:
in practice, it suffices to take \(m = O(d\log d)\) to have an accurate estimate of the attention matrix.
As long as \(d \ll L\), the presented mechanism provides space and time complexity gains.