
What is: SRU++?

Source: When Attention Meets Fast Recurrence: Training Language Models with Reduced Compute
Year: 2021
Data Source: CC BY-SA - https://paperswithcode.com

SRU++ is a self-attentive recurrent unit that combines fast recurrence and attention for sequence modeling, extending the SRU unit. The key modification of SRU++ is to incorporate more expressive non-linear operations into the recurrent network. Specifically, given the input sequence represented as a matrix $\mathbf{X} \in \mathbb{R}^{L \times d}$, the attention component computes the query, key and value representations using the following multiplications:

$$
\begin{aligned}
\mathbf{Q} &= \mathbf{W}^{q} \mathbf{X}^{\top} \\
\mathbf{K} &= \mathbf{W}^{k} \mathbf{Q} \\
\mathbf{V} &= \mathbf{W}^{v} \mathbf{Q}
\end{aligned}
$$

where $\mathbf{W}^{q} \in \mathbb{R}^{d' \times d}$ and $\mathbf{W}^{k}, \mathbf{W}^{v} \in \mathbb{R}^{d' \times d'}$ are model parameters, and $d'$ is the attention dimension, which is typically much smaller than $d$. Note that the keys $\mathbf{K}$ and values $\mathbf{V}$ are computed from $\mathbf{Q}$ instead of $\mathbf{X}$, so the weight matrices $\mathbf{W}^{k}$ and $\mathbf{W}^{v}$ are significantly smaller ($d' \times d'$ rather than $d' \times d$).
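The projections can be sketched in NumPy as follows. This is an illustrative sketch under stated assumptions, not the reference implementation: the sizes, the random initialization scale, and the variable name `d_att` (standing for $d'$) are chosen only for the example. It also counts how many parameters the key and value projections save by reading from $\mathbf{Q}$ instead of $\mathbf{X}$.

```python
import numpy as np

# Hypothetical sizes: sequence length L, model dim d, attention dim d' (d_att).
L, d, d_att = 6, 512, 64
rng = np.random.default_rng(0)

X = rng.standard_normal((L, d))                   # input sequence, one row per position
W_q = rng.standard_normal((d_att, d)) * 0.02      # W^q in R^{d' x d}
W_k = rng.standard_normal((d_att, d_att)) * 0.02  # W^k in R^{d' x d'}
W_v = rng.standard_normal((d_att, d_att)) * 0.02  # W^v in R^{d' x d'}

Q = W_q @ X.T  # (d', L)
K = W_k @ Q    # (d', L): keys are computed from Q, not X
V = W_v @ Q    # (d', L): values are computed from Q, not X

# Key/value parameter count when projecting from Q versus from X directly.
kv_params_from_q = W_k.size + W_v.size  # 2 * d'^2
kv_params_from_x = 2 * d_att * d        # 2 * d' * d
print(Q.shape, K.shape, V.shape)        # (64, 6) (64, 6) (64, 6)
print(kv_params_from_q, kv_params_from_x)  # 8192 65536
```

With $d' = d/8$ in this example, the key and value projections need one eighth of the parameters they would need if they read from $\mathbf{X}$.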

Next, we compute a weighted average output $\mathbf{A} \in \mathbb{R}^{d' \times L}$ using scaled dot-product attention:

$$
\mathbf{A}^{\top} = \operatorname{softmax}\left(\frac{\mathbf{Q}^{\top}\mathbf{K}}{\sqrt{d'}}\right)\mathbf{V}^{\top}
$$
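A minimal NumPy sketch of this equation, assuming $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ are stored one column per position as in the text. No attention mask is applied because the formula above does not include one; the helper names `softmax` and `attention` are illustrative, not taken from the paper's code.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the max along the axis for numerical stability before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """A^T = softmax(Q^T K / sqrt(d')) V^T, with Q, K, V of shape (d', L)."""
    d_att = Q.shape[0]
    scores = Q.T @ K / np.sqrt(d_att)   # (L, L): row i scores query i against every key
    weights = softmax(scores, axis=-1)  # each row sums to 1
    A_T = weights @ V.T                 # (L, d'): weighted average of value vectors
    return A_T.T                        # A in R^{d' x L}

rng = np.random.default_rng(0)
d_att, L = 8, 5
Q, K, V = (rng.standard_normal((d_att, L)) for _ in range(3))
A = attention(Q, K, V)
print(A.shape)  # (8, 5)
```

As a quick sanity check, if all keys are identical every score in a row is equal, so each column of $\mathbf{A}$ reduces to the plain average of the value vectors.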

The final output $\mathbf{U}$ required by the elementwise recurrence is obtained by another linear projection:

$$
\mathbf{U}^{\top} = \mathbf{W}^{o}\left(\mathbf{Q} + \alpha \cdot \mathbf{A}\right)
$$

where $\alpha \in \mathbb{R}$ is a learned scalar and $\mathbf{W}^{o} \in \mathbb{R}^{3d \times d'}$ is a parameter matrix. $\mathbf{Q} + \alpha \cdot \mathbf{A}$ is a residual connection that improves gradient propagation and stabilizes training. We initialize $\alpha$ to zero; as a result,

$$
\mathbf{U}^{\top} = \mathbf{W}^{o}\mathbf{Q} = \left(\mathbf{W}^{o}\mathbf{W}^{q}\right)\mathbf{X}^{\top}
$$

initially falls back to a linear transformation of the input $\mathbf{X}$, skipping the attention transformation. Intuitively, skipping attention encourages the model to leverage recurrence to capture sequential patterns during the early stage of training. As $|\alpha|$ grows, the attention mechanism can learn long-range dependencies for the model. In addition, $\mathbf{W}^{o}\mathbf{W}^{q}$ can be interpreted as applying a matrix factorization trick with a small inner dimension $d' < d$, which reduces the total number of parameters ($3dd' + d'd = 4dd'$ instead of $3d^2$ for a single $3d \times d$ projection). The figure compares SRU, SRU with this factorization trick (but without attention), and SRU++.
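The output projection and the effect of initializing $\alpha$ to zero can be checked numerically. The sketch below assumes small hypothetical sizes and uses a random matrix as a stand-in for the attention output $\mathbf{A}$; the function name `output_projection` is illustrative. It only confirms the algebra and is not the paper's code.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d, d_att = 4, 16, 4  # hypothetical small sizes with d' < d

X = rng.standard_normal((L, d))
W_q = rng.standard_normal((d_att, d))      # W^q in R^{d' x d}
W_o = rng.standard_normal((3 * d, d_att))  # W^o in R^{3d x d'}
Q = W_q @ X.T                              # (d', L)
A = rng.standard_normal((d_att, L))        # stand-in for the attention output

def output_projection(Q, A, alpha):
    # U^T = W^o (Q + alpha * A); Q + alpha * A is the residual connection.
    return W_o @ (Q + alpha * A)           # (3d, L)

U_T = output_projection(Q, A, alpha=0.0)
# With alpha = 0 the attention term vanishes and U^T = (W^o W^q) X^T.
print(np.allclose(U_T, (W_o @ W_q) @ X.T))  # True

# Parameter count: factorized W^o W^q versus one full 3d x d projection.
factorized = W_o.size + W_q.size  # 3d*d' + d'*d = 4*d*d'
full = 3 * d * d
print(factorized, full)  # 256 768
```

Once $\alpha$ moves away from zero during training, the attention term contributes and $\mathbf{U}^{\top}$ is no longer a purely linear function of $\mathbf{X}$.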

The last modification is adding layer normalization to each SRU++ layer. We apply normalization after the attention operation and before the matrix multiplication with $\mathbf{W}^{o}$:

$$
\mathbf{U}^{\top} = \mathbf{W}^{o}\operatorname{layernorm}\left(\mathbf{Q} + \alpha \cdot \mathbf{A}\right)
$$

This implementation uses post-layer normalization, in which the normalization is applied after the residual connection.
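Putting the pieces together, a compact NumPy sketch of how one SRU++ layer could produce $\mathbf{U}$ from $\mathbf{X}$ is shown below. It rests on several assumptions: layer normalization is simplified to have no learnable gain or bias, no attention mask is used, the initialization scale is arbitrary, and `sru_pp_projection` is a hypothetical name. It stops at $\mathbf{U}$; the elementwise recurrence that consumes it is not shown.

```python
import numpy as np

def layer_norm(Z, eps=1e-5):
    # Normalize each position (column) over its d' features.
    # Simplified: learnable gain and bias are omitted (assumed 1 and 0).
    mean = Z.mean(axis=0, keepdims=True)
    var = Z.var(axis=0, keepdims=True)
    return (Z - mean) / np.sqrt(var + eps)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def sru_pp_projection(X, W_q, W_k, W_v, W_o, alpha):
    """Compute U (L x 3d) from X (L x d) following the equations above."""
    Q = W_q @ X.T                                    # (d', L)
    K = W_k @ Q                                      # (d', L)
    V = W_v @ Q                                      # (d', L)
    d_att = Q.shape[0]
    A = (softmax(Q.T @ K / np.sqrt(d_att)) @ V.T).T  # (d', L)
    U_T = W_o @ layer_norm(Q + alpha * A)            # post-LN: normalize after the residual add
    return U_T.T                                     # (L, 3d), input to the elementwise recurrence

rng = np.random.default_rng(0)
L, d, d_att = 5, 32, 8
X = rng.standard_normal((L, d))
params = dict(
    W_q=rng.standard_normal((d_att, d)) / np.sqrt(d),
    W_k=rng.standard_normal((d_att, d_att)) / np.sqrt(d_att),
    W_v=rng.standard_normal((d_att, d_att)) / np.sqrt(d_att),
    W_o=rng.standard_normal((3 * d, d_att)) / np.sqrt(d_att),
)
U = sru_pp_projection(X, alpha=0.0, **params)
print(U.shape)  # (5, 96)
```

Because normalization sits after the residual addition, it acts on $\mathbf{Q} + \alpha \cdot \mathbf{A}$ as a whole rather than on the attention branch alone.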