Data-Mining-Lec4

Attention Mechanism

Attention block: takes a sequence of vectors \(x_1, \dots, x_L \in \mathbb{R}^d\) and outputs another sequence \(x_1', \dots, x_L' \in \mathbb{R}^d\).

Within the block there are three weight matrices \(W_Q, W_K, W_V \in \mathbb{R}^{d\times d}\). Stack the inputs as rows, \(X = \begin{bmatrix}x_1\\x_2\\...\\x_L\end{bmatrix} \in \mathbb{R}^{L\times d}\), and project: \[ \begin{cases} Q = XW_Q\\ K = XW_K\\V=XW_V\end{cases} \quad \text{each in } \mathbb{R}^{L\times d} \] \[ Q = \begin{bmatrix}q_1\\q_2\\...\\q_L\end{bmatrix}, \quad K = \begin{bmatrix}k_1\\k_2\\...\\k_L\end{bmatrix}, \quad V = \begin{bmatrix}v_1\\v_2\\...\\v_L\end{bmatrix} \] All these weight matrices are learned.

The queries \(q_1, ..., q_L\) (indexing rows) and keys \(k_1, ..., k_L\) (indexing columns) define the \(L\times L\) attention matrix \(A\): \[ A_{ij} = K(q_i, k_j) \] for some kernel \(K\) (for standard softmax attention, \(K(q, k) = e^{qk^T/\sqrt{d}}\)).
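A minimal numpy sketch of these definitions (names are illustrative; a single head and the softmax kernel \(K(q,k)=e^{qk^T/\sqrt{d}}\) are assumed). Later sketches in these notes continue this example.

```python
import numpy as np

L, d = 6, 8                       # sequence length, embedding dimension
rng = np.random.default_rng(0)

X = rng.normal(size=(L, d))       # stacked input embeddings, one row per token
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))   # learned in practice

Q, K, V = X @ W_Q, X @ W_K, X @ W_V          # queries, keys, values: each L x d
A = np.exp(Q @ K.T / np.sqrt(d))             # A_ij = K(q_i, k_j) with the softmax kernel
```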

...

Non-normalized attention: \(AV\)

The rows of \(AV\) are the new token embeddings:

\[\begin{bmatrix} x_1'\\...\\x_L'\end{bmatrix} = AV \in \mathbb{R}^{L\times d}\] \[ A \in \mathbb{R}^{L\times L},\quad V \in \mathbb{R}^{L\times d}\] \[ x_i' = \sum_{j=1}^{L}{K(q_i, k_j)\,v_j} \]

Normalized attention:

\[ x_i' = \sum_{j=1}^{L}{\frac{K(q_i, k_j)\,v_j}{\sum_{s=1}^{L}{K(q_i, k_s)}}} \] The weights sum to 1 and are \(\geq 0\); here \(K: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}^+\).

In matrix form: \[D^{-1}AV\] where \(D = \mathrm{diag}(d_1, \dots, d_L)\) is a diagonal \(L\times L\) matrix whose entries are given by the

partition function: \[ d_i = \sum_{s=1}^{L}{K(q_i, k_s)} \]
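Continuing the numpy sketch above, the non-normalized and normalized outputs (a sketch of this single block, not of a full multi-head transformer layer):

```python
# non-normalized attention: rows of A @ V are the new embeddings x_i'
out_unnorm = A @ V                    # L x d

# normalized attention: D^{-1} A V with D = diag(d_1, ..., d_L)
d_part = A.sum(axis=1)                # partition functions d_i = sum_s K(q_i, k_s)
out_norm = (A / d_part[:, None]) @ V  # equivalent to row-wise softmax attention
```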

Unidirectional Attention

Non-normalized setting:

\[ A_{masked}V \]

The masked attention matrix zeroes out the upper-triangular part of \(A\), so token \(i\) attends only to tokens \(j \leq i\):

\[ (A_{masked})_{ij} = \begin{cases} A_{ij} & j \leq i\\ 0 & j > i\end{cases} \]

Normalized:

\[D_{masked}^{-1}A_{masked}V\]

where \(D_{masked}\) is a diagonal \(L\times L\) matrix with entries

\[ d_i = \sum_{s=1}^{i}{K(q_i, k_s)} \]
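Continuing the same numpy example, a minimal sketch of the unidirectional (causal) case: mask \(A\) and renormalize over the surviving entries.

```python
# causal mask: token i attends only to tokens j <= i
A_masked = np.tril(A)                        # zero out the upper-triangular part

out_uni_unnorm = A_masked @ V                # non-normalized: A_masked V
d_masked = A_masked.sum(axis=1)              # d_i = sum_{s<=i} K(q_i, k_s)
out_uni_norm = (A_masked / d_masked[:, None]) @ V   # normalized: D_masked^{-1} A_masked V
```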

Problems with the standard attention algorithm

  • Time and space complexity for computing attention is quadratic in \(L\) (so it cannot be used for very long sequences)

Sparsification:

  • attend just to a few tokens (either fixed or learned); a numpy sketch of these masks follows below
    • in the unidirectional case (only the lower-triangular part is non-zero): the last \(l\) tokens, i.e. a band of width \(l\) along the diagonal inside the lower-triangular part;
    • in the bidirectional case: the closest \(l\) tokens, i.e. a strip around the diagonal.

For graph data, people often attend only to neighbors (graph attention methods).

  • attend to a few tokens, but learn which ones to attend to.

How to choose which keys to attend to? The close ones, e.g. the 10 closest.

If \(Q=K\) (i.e. \(W_Q = W_K\)), the queries can be clustered into groups:

  • clustering
  • hashing / approximate nearest-neighbour search: hash each query so that its (e.g. 10) closest keys can be found cheaply, reducing the complexity
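A sketch of the fixed sparsification patterns from the first bullet above, continuing the running numpy example; the learned / hashed variants build the mask from data (e.g. by bucketing similar queries) but apply it the same way. The value of `l` and the mask names are illustrative.

```python
l = 2
idx = np.arange(L)
offset = idx[:, None] - idx[None, :]                 # i - j for every (i, j) pair

# unidirectional: attend only to the last l tokens (band inside the lower triangle)
mask_causal_local = (offset >= 0) & (offset < l)

# bidirectional: attend only to the closest l tokens on each side (diagonal strip)
mask_local = np.abs(offset) <= l

A_sparse = np.where(mask_local, A, 0.0)              # only O(L*l) non-zero entries
out_sparse = (A_sparse / A_sparse.sum(axis=1, keepdims=True)) @ V
# the unidirectional case is analogous, using mask_causal_local instead
```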

Efficient Dense Attention

Can we approximate the attention matrix while it stays dense? Idea: decomposition.

Bidirectional non-normalized attention

\(AV\)

Let's try to rewrite \(A\) as \(A\approx F_1\cdot F_2 \cdots F_k\) for some simpler matrices \(F_i\).

\[ (F_1\cdot F_2\cdots F_k)\cdot V = F_1\cdot (F_2\cdot(\cdots(F_k\cdot V)\cdots)) \] Multiplying from the right keeps the intermediate products cheap (provided the \(F_i\) have small inner dimensions), while the result is still an approximation of \[ AV \in \mathbb{R}^{L \times d} \]

Q: Can we write \(A = F_1 \cdot F_2\), i.e. \(L\times L = (L \times m) \cdot (m\times L)\), where \(m < L\)?

Note that \(A\) can be full rank (\(\det A \neq 0\)), e.g. when it has large entries on the diagonal and small entries everywhere else.

Look at the rank: \(F_1 F_2\) has rank \(\leq m < L\), so an exact deterministic decomposition does not exist.
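A quick numerical illustration of this rank argument, continuing the numpy example (the concrete matrices are arbitrary):

```python
m = 3
F1, F2 = rng.normal(size=(L, m)), rng.normal(size=(m, L))
print(np.linalg.matrix_rank(F1 @ F2))     # at most m = 3 < L

# a full-rank attention-like matrix: large diagonal, small off-diagonal entries
A_full = 10 * np.eye(L) + 0.01 * rng.normal(size=(L, L))
print(np.linalg.matrix_rank(A_full))      # L = 6: full rank, so no exact F1 @ F2
```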

What about random \(F_i\)?

  • Can we find random matrices \(F_1\) and \(F_2\) such that:
    • \(A = E[F_1F_2]\)
    • the error of the approximation is small

\[ A_{ij} = e^{\frac{q_ik_j^T}{\sqrt{d}}} = e^{\bar{q_i}\bar{k_j}^T}, \qquad \text{where } \bar{q_i} = \frac{q_i}{d^{1/4}},\ \bar{k_j} = \frac{k_j}{d^{1/4}} \]

We conclude that

\[ A_{ij} = E[\phi(\bar{q_i})\phi(\bar{k_j})^T] \]

\[ Q' = \begin{bmatrix} \phi(\bar{q_1})\\...\\\phi(\bar{q_L}) \end{bmatrix} \] \[ K' = \begin{bmatrix} \phi(\bar{k_1})\\...\\\phi(\bar{k_L}) \end{bmatrix} \]

\[ (Q'(K')^T)_{ij} = \phi(\bar{q_i})\phi(\bar{k_j})^T\] \[ E[(Q'(K')^T)_{ij}] = E[\phi(\bar{q_i})\phi(\bar{k_j})^T]\] \[ A = E[Q'(K')^T]\]

\[ \phi_{SM}(x) = e^{\frac{||x||^2}{2}}\phi_{Gauss}(x), \qquad \phi_{Gauss}(x) = \frac{1}{\sqrt{v}} \begin{bmatrix}\cos(w_1^Tx)\\...\\\cos(w_v^Tx)\\\sin(w_1^Tx)\\...\\\sin(w_v^Tx) \end{bmatrix}, \qquad w_i \sim \mathcal{N}(0, I_d) \]
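A sketch of these trigonometric random features in numpy, continuing the running example (the function name `phi_sm` and the test values are mine; in attention the map is applied to the rescaled \(\bar{q_i}, \bar{k_j}\)). The dot product of features is a Monte-Carlo estimate of \(e^{x^Ty}\):

```python
def phi_sm(x, W):
    """Trigonometric random features for the softmax kernel exp(x @ y).

    W has one random vector w_i ~ N(0, I_d) per row; the feature dimension
    is twice the number of rows (cosines stacked on sines).
    """
    v = W.shape[0]
    proj = W @ x                                  # w_i^T x for every random w_i
    return np.exp(x @ x / 2) / np.sqrt(v) * np.concatenate([np.cos(proj), np.sin(proj)])

v = 1000
W = rng.normal(size=(v, d))                       # random projection directions
x, y = 0.3 * rng.normal(size=d), 0.3 * rng.normal(size=d)

print(np.exp(x @ y))                              # exact softmax-kernel value
print(phi_sm(x, W) @ phi_sm(y, W))                # random-feature estimate, close for large v
```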

\[ AV = E[Q'(K')^T]V = E[Q'(K')^TV] \]

\[ AV \approx Q'(K')^TV = Q'\left((K')^TV\right)\]

\[ Q'\in \mathbb{R}^{L\times m},\quad (K')^T \in \mathbb{R}^{m\times L},\quad V \in \mathbb{R}^{L\times d} \] where \(m\) is the feature dimension (for the trigonometric features above, \(m = 2v\)). Space complexity: \(O(L\times d + m\times d + L\times m)\).

Time complexity: \(O(mLd)\) versus \(O(L^2d)\) for standard attention, which is a gain whenever \(m \ll L\).
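Putting it together on the running example: build \(Q'\) and \(K'\), and multiply in the cheap order so that no \(L\times L\) matrix is ever formed (a sketch; with trigonometric features the estimate has noticeable variance unless \(v\) is large).

```python
Q_bar, K_bar = Q / d ** 0.25, K / d ** 0.25          # rescaled queries and keys
Q_prime = np.stack([phi_sm(q, W) for q in Q_bar])    # L x m feature matrix (m = 2v)
K_prime = np.stack([phi_sm(k, W) for k in K_bar])    # L x m

AV_exact = A @ V                                     # O(L^2 d), materializes A
AV_approx = Q_prime @ (K_prime.T @ V)                # O(mLd), never forms an L x L matrix
print(np.max(np.abs(AV_exact - AV_approx)))          # shrinks as v grows
```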

Problems with trigonometric features

For normalized attention we need the row sums (partition functions) \[ \sum_{s=1}^{L}{K(q_i, k_s)}\] If many entries in a row are close to 0, then many of the estimators' values can be negative, making the estimated partition function unreliable.
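A small experiment (continuing the running example) illustrating the issue: when the true kernel value is close to 0, i.e. \(x\) and \(y\) are far apart, the trigonometric estimator's mean is near 0 while its fluctuations are not, so a large fraction of the estimates come out negative. The specific points and feature counts below are arbitrary.

```python
x_far, y_far = np.zeros(d), np.zeros(d)
x_far[0], y_far[0] = 3.0, -3.0                 # exp(x @ y) = exp(-9), nearly 0

def one_estimate(v_small=32):
    W_small = rng.normal(size=(v_small, d))    # both feature vectors must share the same W
    return phi_sm(x_far, W_small) @ phi_sm(y_far, W_small)

estimates = np.array([one_estimate() for _ in range(1000)])
print(np.mean(estimates < 0))                  # a large fraction of estimates is negative
```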

In general, as long as the kernel used for attention can be written as

\[ K(x,y) = E[\phi(x)\phi(y)^T] \]

for some \(\phi: \mathbb{R}^{d} \to \mathbb{R}^{m}\) (\(m < \infty\)), deterministic or random, we can get an attention computation mechanism with \(O(mLd)\) time complexity and \(O(mL+Ld+md)\) space complexity.

Remark:

If random features are used, the random features for attention computation should be periodically redrawn in downstream algorithms that use attention-based models.

Remark:

In practice, it suffices to take \(m = O(d\log d)\) to have an accurate estimate of the attention matrix.

As long as \(d \ll L\), the presented mechanism provides space and time complexity gains.