What is: Fixed Factorized Attention?

Source: Generating Long Sequences with Sparse Transformers
Year: 2019
Data Source: CC BY-SA - https://paperswithcode.com

Fixed Factorized Attention is a factorized attention pattern where specific cells summarize previous locations and propagate that information to all future cells. It was proposed as part of the Sparse Transformer architecture.

A self-attention layer maps a matrix of input embeddings $X$ to an output matrix and is parameterized by a connectivity pattern $S = \{S_1, \dots, S_n\}$, where $S_i$ denotes the set of indices of the input vectors to which the $i$th output vector attends. The output vector is a weighted sum of transformations of the input vectors:

$$\text{Attend}(X, S) = \left(a(\mathbf{x}_i, S_i)\right)_{i \in \{1, \dots, n\}}$$

$$a(\mathbf{x}_i, S_i) = \text{softmax}\left(\frac{(W_q \mathbf{x}_i) K_{S_i}^{T}}{\sqrt{d}}\right) V_{S_i}$$

$$K_{S_i} = \left(W_k \mathbf{x}_j\right)_{j \in S_i}$$

$$V_{S_i} = \left(W_v \mathbf{x}_j\right)_{j \in S_i}$$

Here $W_q$, $W_k$, and $W_v$ represent the weight matrices that transform a given $\mathbf{x}_i$ into a query, key, or value, and $d$ is the inner dimension of the queries and keys. The output at each position is a sum of the values weighted by the scaled dot-product similarity of the keys and queries.
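
The attention equations can be sketched in plain Python, without a deep-learning library, to make the role of $S_i$ concrete: each output position only sees the keys and values indexed by its own $S_i$. This is a minimal, unoptimized illustration under assumptions not taken from the source: positions are 0-based, $S$ is a list of Python sets, matrices are lists of rows, and the helper names (`attend`, `attend_one`, `matvec`, `softmax`) are hypothetical rather than taken from the Sparse Transformer code.

```python
import math
from typing import List, Sequence, Set

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Sequence[float]) -> Vector:
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(w * xk for w, xk in zip(row, x)) for row in W]


def softmax(scores: Sequence[float]) -> Vector:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend_one(X: Sequence[Vector], i: int, S_i: Set[int],
               W_q: Matrix, W_k: Matrix, W_v: Matrix) -> Vector:
    """a(x_i, S_i): output vector i, attending only to the positions in S_i."""
    idx = sorted(S_i)
    q = matvec(W_q, X[i])                      # query W_q x_i
    keys = [matvec(W_k, X[j]) for j in idx]    # rows of K_{S_i}
    values = [matvec(W_v, X[j]) for j in idx]  # rows of V_{S_i}
    d = len(q)                                 # inner dimension of queries/keys
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    weights = softmax(scores)
    # Weighted sum of the value vectors, one output coordinate at a time.
    return [sum(w * v[col] for w, v in zip(weights, values))
            for col in range(len(values[0]))]


def attend(X: Sequence[Vector], S: Sequence[Set[int]],
           W_q: Matrix, W_k: Matrix, W_v: Matrix) -> List[Vector]:
    """Attend(X, S): one output vector per position, position i using S[i]."""
    return [attend_one(X, i, S[i], W_q, W_k, W_v) for i in range(len(X))]


# Toy usage: 3 positions, 2-dim embeddings, identity projections,
# and full causal connectivity S_i = {0, ..., i} (0-based).
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
I = [[1.0, 0.0], [0.0, 1.0]]
S = [{0}, {0, 1}, {0, 1, 2}]
out = attend(X, S, I, I, I)
print(out[0])  # [1.0, 0.0]: a single attended position gets weight 1
```

Changing only `S` changes which positions each output can draw from, which is the lever that the factorized patterns below pull on.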

Full self-attention for autoregressive models defines $S_i = \{j : j \leq i\}$, allowing every element to attend to all previous positions and its own position.
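
For comparison with the sparse patterns, the full causal pattern can be built explicitly. This small sketch (hypothetical helper names, 0-based positions) shows that it allows $n(n+1)/2$ pairs in total, a count that grows quadratically with sequence length and that a sparse choice of $A$ is meant to reduce.

```python
from typing import List, Set


def full_causal_pattern(n: int) -> List[Set[int]]:
    """S_i = {j : j <= i} for 0-based positions i = 0, ..., n - 1."""
    return [set(range(i + 1)) for i in range(n)]


def num_connections(S: List[Set[int]]) -> int:
    """Total number of (i, j) pairs the connectivity pattern allows."""
    return sum(len(S_i) for S_i in S)


S = full_causal_pattern(6)
print(S[2])                # {0, 1, 2}
print(num_connections(S))  # 21, i.e. 6 * 7 / 2
```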

Factorized self-attention instead has $p$ separate attention heads, where the $m$th head defines a subset of the indices $A_i^{(m)} \subset \{j : j \leq i\}$ and lets $S_i = A_i^{(m)}$. The goal with the Sparse Transformer was to find efficient choices for the subset $A$.

Formally, for Fixed Factorized Attention with stride (block length) $l$, $A_i^{(1)} = \{j : \lfloor j/l \rfloor = \lfloor i/l \rfloor\}$, where $\lfloor \cdot \rfloor$ denotes the floor operation, and $A_i^{(2)} = \{j : j \bmod l \in \{t, t+1, \ldots, l\}\}$, where $t = l - c$ and $c$ is a hyperparameter. The $i$th output vector of the attention head attends to all input vectors either from $A_i^{(1)}$ or $A_i^{(2)}$. This pattern can be visualized in the figure to the right.
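
A direct, set-based sketch of the two fixed heads follows, assuming 0-based positions. Each set is intersected with $\{j : j \le i\}$, as required by $A_i^{(m)} \subset \{j : j \leq i\}$ in the factorized definition, and because $j \bmod l$ never reaches $l$, the condition $j \bmod l \in \{t, \ldots, l\}$ reduces to $j \bmod l \ge t$. The function names are hypothetical, and building explicit sets is only meant for inspecting the pattern, not as an efficient kernel.

```python
from typing import List, Set


def fixed_a1(i: int, l: int) -> Set[int]:
    """A_i^(1) restricted to j <= i: the causal part of i's own block of length l."""
    block_start = (i // l) * l
    return set(range(block_start, i + 1))


def fixed_a2(i: int, l: int, c: int) -> Set[int]:
    """A_i^(2) restricted to j <= i: positions with j mod l in {l - c, ..., l - 1}."""
    t = l - c
    return {j for j in range(i + 1) if j % l >= t}


def fixed_pattern(n: int, l: int, c: int) -> List[Set[int]]:
    """For each position i, the union of A_i^(1) and A_i^(2)."""
    return [fixed_a1(i, l) | fixed_a2(i, l, c) for i in range(n)]


S = fixed_pattern(n=16, l=4, c=1)
print(sorted(fixed_a1(9, 4)))     # [8, 9]
print(sorted(fixed_a2(9, 4, 1)))  # [3, 7]
print(sorted(S[9]))               # [3, 7, 8, 9]
```

Position 9 sees its own block (8 and 9) through $A^{(1)}$ and the last position of each earlier block (3 and 7) through $A^{(2)}$.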

If the stride $l$ is 128 and $c = 8$, then all future positions greater than 128 can attend to positions 120-128, all positions greater than 256 can attend to 248-256, and so forth.
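
The positions in this example can be checked with a short helper, assuming 0-based positions (the name `summary_positions` is hypothetical). With $l = 128$ and $c = 8$, $A^{(2)}$ selects the last $c$ positions of each block, which in 0-based indexing are 120 to 127 and 248 to 255.

```python
from typing import List


def summary_positions(n: int, l: int, c: int) -> List[int]:
    """0-based positions j < n selected by A^(2), i.e. j mod l >= l - c."""
    return [j for j in range(n) if j % l >= l - c]


cols = summary_positions(300, l=128, c=8)
print(len(cols))         # 16: two blocks of c = 8 summary positions
print(cols[0], cols[7])  # 120 127
print(cols[8], cols[-1]) # 248 255

# A query at position 200 (causal, so j <= 200) only sees the first group.
print([j for j in cols if j <= 200] == list(range(120, 128)))  # True
```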

A fixed-attention pattern with $c = 1$ limits the expressivity of the network significantly, as many representations in the network are only used for one block whereas a small number of locations are used by all blocks. The authors found that choosing $c \in \{8, 16, 32\}$ for typical values of $l \in \{128, 256\}$ performs well, although this increases the computational cost of this method by $c$ in comparison to the strided attention.
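
The cost side of this trade-off can be seen by counting connections. Reading the definitions directly (0-based positions, $c \le l$, each query attending to the union of $A_i^{(1)}$ and $A_i^{(2)}$ with $j \le i$), row $i$ in block $b = \lfloor i/l \rfloor$ attends to $(i - bl + 1) + bc$ positions, so the summary term grows linearly in $c$. The sketch below uses hypothetical helpers that only count allowed pairs; it does not reproduce the authors' runtime comparison with strided attention.

```python
def fixed_row_size(i: int, l: int, c: int) -> int:
    """Size of the union of A_i^(1) and A_i^(2), restricted to j <= i.

    Assumes 0-based positions and c <= l. A^(1) contributes the causal part
    of i's own block (i - block_start + 1 positions); A^(2) adds the c summary
    positions of every earlier block, since the current block's summaries up
    to i already lie inside A^(1).
    """
    block = i // l
    return (i - block * l + 1) + block * c


def fixed_connections(n: int, l: int, c: int) -> int:
    """Total number of allowed (i, j) pairs over n positions."""
    return sum(fixed_row_size(i, l, c) for i in range(n))


n, l = 1024, 128
full = n * (n + 1) // 2  # pairs allowed by full causal attention
for c in (1, 8, 16, 32):
    pairs = fixed_connections(n, l, c)
    print(f"c={c:2d}: {pairs} pairs, {pairs / full:.1%} of full causal attention")
```

The closed form avoids materializing sets; for small sizes it can be cross-checked against the set-based definition.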

Additionally, the authors found that when using multiple heads, having them attend to distinct subblocks of length $c$ within the block of size $l$ was preferable to having them attend to the same subblock.
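
One way to give each head its own subblock of length $c$ is sketched below. The assignment (head $h$ takes offsets $l-(h+1)c$ through $l-hc-1$ within every block) is a hypothetical choice for illustration; the text only says that distinct subblocks worked better than a shared one, not how they were assigned. Head 0 reproduces the single-head $A^{(2)}$ with $t = l - c$, and the scheme requires $(h+1)c \le l$ for every head.

```python
from typing import Set


def head_offsets(head: int, l: int, c: int) -> range:
    """Hypothetical choice: head h owns offsets l-(h+1)c .. l-hc-1 inside every block."""
    start = l - (head + 1) * c
    if start < 0:
        raise ValueError("distinct subblocks need (head + 1) * c <= l")
    return range(start, start + c)


def head_summaries(i: int, l: int, c: int, head: int) -> Set[int]:
    """Causal summary positions j <= i whose offset j mod l lies in this head's subblock."""
    offsets = head_offsets(head, l, c)
    return {j for j in range(i + 1) if j % l in offsets}


l, c, num_heads = 128, 8, 4
for h in range(num_heads):
    r = head_offsets(h, l, c)
    print(f"head {h}: offsets {r.start}..{r.stop - 1}")
```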