
What is: SRU?

Source: Simple Recurrent Units for Highly Parallelizable Recurrence
Year: 2017
Data Source: CC BY-SA - https://paperswithcode.com

SRU, or Simple Recurrent Unit, is a recurrent neural unit with a light form of recurrence. SRU exhibits the same level of parallelism as convolution and feed-forward nets. This is achieved by balancing sequential dependence and independence: while the state computation of SRU is time-dependent, each state dimension is independent. This simplification enables CUDA-level optimizations that parallelize the computation across hidden dimensions and time steps, effectively using the full capacity of modern GPUs.
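To make the parallelism concrete, here is a minimal NumPy sketch (all names, shapes, and the fixed gate value are illustrative, not from the paper's code). The expensive matrix multiplications depend only on the inputs, so they batch over all time steps at once; only a cheap elementwise update remains in the sequential loop, and it is independent across state dimensions:

```python
import numpy as np

T, d = 128, 64                    # sequence length and hidden size (illustrative)
X = np.random.randn(T, d)         # input sequence
W = np.random.randn(d, d)         # one of the input projections

# The projections depend only on the inputs, never on previous states,
# so they can be computed for all time steps in a single batched matmul:
U = X @ W.T                       # (T, d), fully parallel over time

# Only the cheap elementwise recurrence is sequential, and each of the
# d state dimensions evolves independently of the others:
c = np.zeros(d)
f = 0.5                           # fixed gate value, for illustration only
for t in range(T):
    c = f * c + (1 - f) * U[t]    # elementwise: parallel across dimensions
```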

SRU also replaces the use of convolutions (i.e., n-gram filters), as in QRNN and KNN, with more recurrent connections. This retains modeling capacity while using less computation (and fewer hyper-parameters). Additionally, SRU improves the training of deep recurrent models by employing highway connections and a parameter initialization scheme tailored for gradient propagation in deep architectures.

A single layer of SRU involves the following computation:

$$\mathbf{f}_{t} = \sigma\left(\mathbf{W}_{f} \mathbf{x}_{t} + \mathbf{v}_{f} \odot \mathbf{c}_{t-1} + \mathbf{b}_{f}\right)$$
$$\mathbf{c}_{t} = \mathbf{f}_{t} \odot \mathbf{c}_{t-1} + \left(1 - \mathbf{f}_{t}\right) \odot \left(\mathbf{W} \mathbf{x}_{t}\right)$$
$$\mathbf{r}_{t} = \sigma\left(\mathbf{W}_{r} \mathbf{x}_{t} + \mathbf{v}_{r} \odot \mathbf{c}_{t-1} + \mathbf{b}_{r}\right)$$
$$\mathbf{h}_{t} = \mathbf{r}_{t} \odot \mathbf{c}_{t} + \left(1 - \mathbf{r}_{t}\right) \odot \mathbf{x}_{t}$$

where $\mathbf{W}$, $\mathbf{W}_{f}$ and $\mathbf{W}_{r}$ are parameter matrices and $\mathbf{v}_{f}$, $\mathbf{v}_{r}$, $\mathbf{b}_{f}$ and $\mathbf{b}_{r}$ are parameter vectors to be learnt during training. The complete architecture decomposes into two sub-components: a light recurrence and a highway network.
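As a minimal illustration, the four equations can be implemented directly in NumPy. The function below is a sketch, not the authors' optimized CUDA implementation; names mirror the symbols above, and it assumes the input and hidden dimensions are equal so the highway term $(1 - \mathbf{r}_{t}) \odot \mathbf{x}_{t}$ is well-typed:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sru_layer(X, W, W_f, W_r, v_f, v_r, b_f, b_r, c0=None):
    """Single SRU layer following the four equations above.

    X: (T, d) input sequence; W, W_f, W_r: (d, d) parameter matrices;
    v_f, v_r, b_f, b_r: (d,) parameter vectors. Assumes input and
    hidden sizes are equal so the highway term (1 - r_t) * x_t type-checks.
    """
    T, d = X.shape
    c = np.zeros(d) if c0 is None else c0

    # Input projections depend only on X, so they batch over all time steps:
    U   = X @ W.T     # W   x_t for every t
    U_f = X @ W_f.T   # W_f x_t for every t
    U_r = X @ W_r.T   # W_r x_t for every t

    H = np.empty((T, d))
    for t in range(T):
        c_prev = c
        f = sigmoid(U_f[t] + v_f * c_prev + b_f)   # forget gate f_t
        c = f * c_prev + (1 - f) * U[t]            # light recurrence c_t
        r = sigmoid(U_r[t] + v_r * c_prev + b_r)   # reset gate r_t
        H[t] = r * c + (1 - r) * X[t]              # highway output h_t
    return H, c
```

Per-timestep work in the loop is purely elementwise, which is what permits the CUDA-level fusion across hidden dimensions described above.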

The light recurrence component successively reads the input vectors $\mathbf{x}_{t}$ and computes the sequence of states $\mathbf{c}_{t}$ capturing sequential information. The computation resembles other recurrent networks such as LSTM, GRU and RAN. Specifically, a forget gate $\mathbf{f}_{t}$ controls the information flow, and the state vector $\mathbf{c}_{t}$ is determined by adaptively averaging the previous state $\mathbf{c}_{t-1}$ and the current observation $\mathbf{W}\mathbf{x}_{t}$ according to $\mathbf{f}_{t}$.
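A small numeric example of this adaptive averaging (values chosen purely for illustration): each state dimension interpolates independently between its previous value and the new observation, weighted by its own gate value:

```python
import numpy as np

c_prev = np.array([0.2, -1.0])   # previous state c_{t-1}
wx_t   = np.array([1.0,  0.5])   # current observation W x_t
f_t    = np.array([0.9,  0.1])   # per-dimension forget gate

# Each dimension is a convex combination of old state and new observation:
c_t = f_t * c_prev + (1 - f_t) * wx_t
print(c_t)  # prints [0.28 0.35]: dim 0 mostly keeps its state, dim 1 mostly resets
```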