
What is: DExTra?

Source: DeLighT: Deep and Light-weight Transformer
Year: 2020
Data Source: CC BY-SA - https://paperswithcode.com

DExTra (Deep and Light-weight Expand-reduce Transformation) is a light-weight transformation that first expands its input into a higher-dimensional space and then reduces it again, which lets it learn wider representations efficiently.

DExTra maps a $d_m$-dimensional input vector into a high-dimensional space (expansion) and then reduces it to a $d_o$-dimensional output vector (reduction) using $N$ layers of group transformations. In both phases, DExTra uses group linear transformations: they learn local representations by deriving each part of the output from a specific part of the input, and they are more efficient than standard linear transformations. To learn global representations, DExTra shares information between the groups of a group linear transformation using feature shuffling.
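To make the efficiency claim concrete, the following Python sketch counts the parameters of a standard linear layer and of a group linear layer with the same input and output sizes. It assumes equal-sized groups and one bias entry per output feature; the function names are hypothetical.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Parameters of a standard linear layer: a d_in x d_out weight matrix plus a bias."""
    return d_in * d_out + d_out


def group_linear_params(d_in: int, d_out: int, groups: int) -> int:
    """Parameters of a group linear layer: `groups` independent
    (d_in/groups) x (d_out/groups) weight blocks plus a d_out-sized bias (assumed)."""
    if d_in % groups or d_out % groups:
        raise ValueError("d_in and d_out must both be divisible by groups")
    return groups * (d_in // groups) * (d_out // groups) + d_out


if __name__ == "__main__":
    for g in (1, 2, 4, 8):
        print(g, group_linear_params(512, 1024, g))
    # 1 525312
    # 2 263168
    # 4 132096
    # 8 66560
```

With $g$ groups, the weight count (and likewise the number of multiply-adds per input vector) shrinks by a factor of $g$, while the bias size is unchanged; with $g = 1$ the count equals that of a standard linear layer.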

Formally, the DExTra transformation is controlled by five configuration parameters: (1) depth $N$, (2) width multiplier $m_w$, (3) input dimension $d_m$, (4) output dimension $d_o$, and (5) maximum number of groups $g_{max}$ in a group linear transformation. In the expansion phase, DExTra projects the $d_m$-dimensional input to a high-dimensional space, $d_{max} = m_w d_m$, linearly using $\lceil N/2 \rceil$ layers. In the reduction phase, DExTra projects the $d_{max}$-dimensional vector to a $d_o$-dimensional space using the remaining $N - \lceil N/2 \rceil$ layers. Mathematically, the output $\mathbf{Y}^{l}$ at each layer $l$ is defined as:

$$
\mathbf{Y}^{l} =
\begin{cases}
\mathcal{F}\left(\mathbf{X}, \mathbf{W}^{l}, \mathbf{b}^{l}, g^{l}\right), & l = 1 \\
\mathcal{F}\left(\mathcal{H}\left(\mathbf{X}, \mathbf{Y}^{l-1}\right), \mathbf{W}^{l}, \mathbf{b}^{l}, g^{l}\right), & \text{otherwise}
\end{cases}
$$

where the number of groups at each layer $l$ is computed as:

$$
g^{l} =
\begin{cases}
\min\left(2^{l-1}, g_{max}\right), & 1 \leq l \leq \lceil N/2 \rceil \\
g^{N-l}, & \text{otherwise}
\end{cases}
$$
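The layer schedule can be sketched in Python as follows. The group counts implement the formula above literally, but two details are assumptions not fixed by the description: the formula needs $g^{0}$ at $l = N$ without defining it, so the sketch uses a single group there; and the intermediate layer widths are linearly interpolated from $d_m$ up to $d_{max}$ and then down to $d_o$. The function name `dextra_schedule` is hypothetical, and the sketch does not round widths to multiples of the group count, which a working implementation would need.

```python
import math


def dextra_schedule(d_m: int, d_o: int, m_w: float, n_layers: int, g_max: int):
    """Return (widths, groups) for the N layers of a DExTra block.

    Groups: g^l = min(2^(l-1), g_max) for 1 <= l <= ceil(N/2), else g^(N-l).
    Widths (ASSUMPTION): linear interpolation d_m -> d_max = m_w * d_m over the
    ceil(N/2) expansion layers, then d_max -> d_o over the remaining layers.
    """
    if n_layers < 2:
        raise ValueError("need at least one expansion and one reduction layer")
    n_expand = math.ceil(n_layers / 2)
    n_reduce = n_layers - n_expand
    d_max = round(m_w * d_m)

    widths = [round(d_m + (d_max - d_m) * l / n_expand) for l in range(1, n_expand + 1)]
    widths += [round(d_max + (d_o - d_max) * l / n_reduce) for l in range(1, n_reduce + 1)]

    g = {l: min(2 ** (l - 1), g_max) for l in range(1, n_expand + 1)}
    g[0] = 1  # ASSUMPTION: g^0 (needed at l = N) is undefined; use a single group.
    for l in range(n_expand + 1, n_layers + 1):
        g[l] = g[n_layers - l]
    groups = [g[l] for l in range(1, n_layers + 1)]
    return widths, groups


if __name__ == "__main__":
    print(dextra_schedule(d_m=128, d_o=128, m_w=2, n_layers=4, g_max=4))
    # ([192, 256, 192, 128], [1, 2, 1, 1])
```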

In the above equations, $\mathcal{F}$ is a group linear transformation function. It takes the input ($\mathbf{X}$ or $\mathcal{H}(\mathbf{X}, \mathbf{Y}^{l-1})$), splits it into $g^{l}$ groups, and then applies a linear transformation with learnable weights $\mathbf{W}^{l}$ and bias $\mathbf{b}^{l}$ to each group independently. The outputs of the groups are then concatenated to produce the final output $\mathbf{Y}^{l}$. The function $\mathcal{H}$ first shuffles the output of each group in $\mathbf{Y}^{l-1}$ and then combines it with the input $\mathbf{X}$ using an input mixer connection.
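A minimal NumPy sketch of $\mathcal{F}$, $\mathcal{H}$, and the layer recursion is shown below. It is not the authors' implementation: the feature shuffle is assumed to be a ShuffleNet-style channel shuffle, the input mixer is assumed to concatenate $\mathbf{X}$ with the shuffled $\mathbf{Y}^{l-1}$, weights are random, and no activation or normalization is applied because the description above does not specify any. The helper `build_dextra` and the example widths and group counts are hypothetical.

```python
import numpy as np


def group_linear(x, weights, biases):
    """F: split x's feature axis into len(weights) equal groups, apply an
    independent linear map (W_i, b_i) to each group, and concatenate."""
    chunks = np.split(x, len(weights), axis=-1)  # d_in must be divisible by the group count
    return np.concatenate([c @ w + b for c, w, b in zip(chunks, weights, biases)], axis=-1)


def feature_shuffle(y, groups):
    """Interleave features across groups (ShuffleNet-style channel shuffle, assumed)."""
    batch, d = y.shape
    return y.reshape(batch, groups, d // groups).transpose(0, 2, 1).reshape(batch, d)


def input_mixer(x, y_prev, groups_prev):
    """H: shuffle Y^{l-1}, then combine it with X.
    ASSUMPTION: the combination is a concatenation along the feature axis."""
    return np.concatenate([x, feature_shuffle(y_prev, groups_prev)], axis=-1)


def build_dextra(rng, d_m, widths, groups):
    """Hypothetical helper: random (W^l, b^l) for every layer.
    Layer 1 sees X (d_m features); later layers see H(X, Y^{l-1})."""
    params, d_prev = [], None
    for d_out, g in zip(widths, groups):
        d_in = d_m if d_prev is None else d_m + d_prev
        w_in, w_out = d_in // g, d_out // g
        weights = [rng.standard_normal((w_in, w_out)) / np.sqrt(w_in) for _ in range(g)]
        biases = [np.zeros(w_out) for _ in range(g)]
        params.append((weights, biases))
        d_prev = d_out
    return params


def dextra_forward(x, params, groups):
    """Y^1 = F(X, W^1, b^1, g^1); Y^l = F(H(X, Y^{l-1}), W^l, b^l, g^l) for l > 1.
    groups[l - 1] is g^{l-1}, used to shuffle the previous layer's output."""
    y = None
    for l, (weights, biases) in enumerate(params):
        inp = x if l == 0 else input_mixer(x, y, groups[l - 1])
        y = group_linear(inp, weights, biases)
    return y


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d_m, widths, groups = 128, [192, 256, 192, 128], [1, 2, 1, 1]
    params = build_dextra(rng, d_m, widths, groups)
    x = rng.standard_normal((3, d_m))
    print(dextra_forward(x, params, groups).shape)  # (3, 128)
```

Under the concatenation assumption, every layer after the first consumes both $\mathbf{X}$ and $\mathbf{Y}^{l-1}$, so its input width is $d_m$ plus the width of the previous layer; `build_dextra` sizes the weight blocks accordingly.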

In the authors' experiments, they use $g_{max} = \lceil d_m / 32 \rceil$ so that each group has at least 32 input elements. Note that (i) group linear transformations reduce to linear transformations when $g^{l} = 1$, and (ii) DExTra is equivalent to a multi-layer perceptron when $g_{max} = 1$.
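Both remarks can be checked numerically with a short sketch (NumPy assumed; helper names are hypothetical). It computes $g_{max} = \lceil d_m / 32 \rceil$ for a few input sizes, confirms that a single-group transformation matches $\mathbf{x}\mathbf{W} + \mathbf{b}$, and shows that $g_{max} = 1$ forces every expansion layer to use one group.

```python
import math

import numpy as np


def g_max_for(d_m: int) -> int:
    """g_max = ceil(d_m / 32), the setting used in the authors' experiments."""
    return math.ceil(d_m / 32)


def expansion_groups(n_expand: int, g_max: int) -> list[int]:
    """Expansion-phase groups g^l = min(2^(l-1), g_max) for l = 1..n_expand."""
    return [min(2 ** (l - 1), g_max) for l in range(1, n_expand + 1)]


def group_linear(x, weights, biases):
    """Split x's features into len(weights) groups and apply one linear map per group."""
    chunks = np.split(x, len(weights), axis=-1)
    return np.concatenate([c @ w + b for c, w, b in zip(chunks, weights, biases)], axis=-1)


if __name__ == "__main__":
    print([g_max_for(d) for d in (128, 256, 512)])  # [4, 8, 16]

    # (i) With a single group, the group linear transformation is x @ W + b.
    rng = np.random.default_rng(0)
    x, w, b = rng.standard_normal((4, 64)), rng.standard_normal((64, 32)), rng.standard_normal(32)
    print(np.allclose(group_linear(x, [w], [b]), x @ w + b))  # True

    # (ii) With g_max = 1, every expansion layer uses one group, i.e. a plain linear layer.
    print(expansion_groups(3, g_max=1))  # [1, 1, 1]
```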