
What is: Spatial-Reduction Attention?

Source: Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions
Year: 2021
Data Source: CC BY-SA - https://paperswithcode.com

Spatial-Reduction Attention, or SRA, is a multi-head attention module used in the Pyramid Vision Transformer architecture. It reduces the spatial scale of the key $K$ and value $V$ before the attention operation, which lowers the computational and memory overhead: with a reduction ratio $R_i$, each query attends to $R_i^2$ times fewer key positions. The SRA in Stage $i$ can be formulated as follows:
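To make the saving concrete, the sketch below counts the entries of one head's score matrix $QK^\top$ with and without spatial reduction. It assumes an illustrative $56 \times 56$ token map and $R_i = 8$; these values and the function names are hypothetical choices for the example. Because each reduced key summarizes an $R_i \times R_i$ window, the score matrix shrinks by a factor of $R_i^2$.

```python
def attention_score_entries(num_queries: int, num_keys: int) -> int:
    """Size of one head's Q K^T score matrix: one score per (query, key) pair."""
    return num_queries * num_keys


def sra_score_entries(h: int, w: int, r: int) -> int:
    """Score-matrix size when K and V are reduced by ratio r along each spatial axis.

    Queries keep all h*w positions; keys/values drop to h*w / r^2 positions.
    Assumes h and w are divisible by r.
    """
    assert h % r == 0 and w % r == 0, "assumes H and W divisible by R"
    n = h * w
    return attention_score_entries(n, n // (r * r))


# Illustrative (assumed) values: a 56x56 token map and reduction ratio R = 8.
full = attention_score_entries(56 * 56, 56 * 56)
reduced = sra_score_entries(56, 56, 8)
print(full, reduced, full // reduced)  # 9834496 153664 64
```

The query count is untouched, so the saving comes entirely from the shorter key/value sequence.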

$$\text{SRA}(Q, K, V) = \text{Concat}\left(\text{head}_{0}, \ldots, \text{head}_{N_i - 1}\right) W^{O}$$

$$\text{head}_{j} = \text{Attention}\left(Q W_{j}^{Q}, \text{SR}(K) W_{j}^{K}, \text{SR}(V) W_{j}^{V}\right)$$
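A minimal pure-Python sketch of these two equations follows. It assumes $\text{Attention}(\cdot)$ is standard scaled dot-product attention, $\text{softmax}(QK^\top / \sqrt{d_{\text{head}}})\,V$, and it takes $\text{SR}(K)$ and $\text{SR}(V)$ as already-reduced inputs. The helper names (`matmul`, `attention`, `sra`) are hypothetical, and plain lists stand in for a tensor library to keep the block dependency-free.

```python
import math
import random


def matmul(a, b):
    """Multiply an (n x k) list-of-rows matrix by a (k x m) one."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]


def attention(q, k, v):
    """Scaled dot-product attention (assumed form of Attention(.))."""
    d = len(q[0])
    k_t = [list(col) for col in zip(*k)]
    scores = matmul(q, k_t)  # one row per query, one column per (reduced) key
    weights = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(weights, v)


def sra(q, k_red, v_red, wq, wk, wv, wo):
    """q: N x C_i queries; k_red, v_red: SR(K), SR(V), each (N / R_i^2) x C_i.

    wq, wk, wv: lists of N_i matrices of shape C_i x d_head; wo: C_i x C_i.
    """
    heads = [
        attention(matmul(q, wq[j]), matmul(k_red, wk[j]), matmul(v_red, wv[j]))
        for j in range(len(wq))
    ]
    # Concat(head_0, ..., head_{N_i - 1}) along the feature axis, row by row.
    concat = [sum((head[r] for head in heads), []) for r in range(len(q))]
    return matmul(concat, wo)


# Toy usage: 16 query tokens (a 4x4 map), R_i = 2 gives 4 reduced keys/values,
# C_i = 4 channels and N_i = 2 heads, so d_head = 2. Weights are random placeholders.
random.seed(0)


def rand(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


C, NUM_HEADS = 4, 2
D_HEAD = C // NUM_HEADS
q = rand(16, C)
k_red, v_red = rand(4, C), rand(4, C)
out = sra(
    q, k_red, v_red,
    [rand(C, D_HEAD) for _ in range(NUM_HEADS)],
    [rand(C, D_HEAD) for _ in range(NUM_HEADS)],
    [rand(C, D_HEAD) for _ in range(NUM_HEADS)],
    rand(C, C),
)
print(len(out), len(out[0]))  # 16 4
```

The output keeps one row per query (16), even though each head only compared those queries against 4 reduced keys.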

where $\text{Concat}(\cdot)$ is the concatenation operation and $C_i$ is the channel dimension of Stage $i$. $W_{j}^{Q} \in \mathbb{R}^{C_i \times d_{\text{head}}}$, $W_{j}^{K} \in \mathbb{R}^{C_i \times d_{\text{head}}}$, $W_{j}^{V} \in \mathbb{R}^{C_i \times d_{\text{head}}}$, and $W^{O} \in \mathbb{R}^{C_i \times C_i}$ are linear projection parameters. $N_i$ is the number of heads in the attention layer of Stage $i$, so the dimension of each head, $d_{\text{head}}$, equals $\frac{C_i}{N_i}$. $\text{SR}(\cdot)$ is the operation that reduces the spatial dimension of the input sequence ($K$ or $V$); it is written as:
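The head-dimension bookkeeping can be checked with a few lines of arithmetic. The hypothetical helpers below return the projection shapes stated above for an assumed stage with $C_i = 128$ and $N_i = 2$. They also show why $W^O$ is square: concatenating $N_i$ heads of width $d_{\text{head}} = C_i / N_i$ gives back exactly $C_i$ columns.

```python
def sra_projection_shapes(c_i: int, n_i: int) -> dict:
    """Shapes of the SRA projection matrices for a stage with C_i channels and N_i heads."""
    if c_i % n_i != 0:
        raise ValueError("C_i must be divisible by N_i")
    d_head = c_i // n_i
    return {
        "d_head": d_head,
        "W_j^Q": (c_i, d_head),
        "W_j^K": (c_i, d_head),
        "W_j^V": (c_i, d_head),
        # N_i heads of width d_head concatenate back to C_i columns, so W^O is C_i x C_i.
        "W^O": (n_i * d_head, c_i),
    }


def projection_param_count(c_i: int, n_i: int) -> int:
    """Weights in Q/K/V/O projections (ignores biases and the SR projection W^S)."""
    s = sra_projection_shapes(c_i, n_i)
    per_head = sum(rows * cols for rows, cols in (s["W_j^Q"], s["W_j^K"], s["W_j^V"]))
    out_rows, out_cols = s["W^O"]
    return n_i * per_head + out_rows * out_cols


shapes = sra_projection_shapes(c_i=128, n_i=2)
print(shapes["d_head"], shapes["W^O"])  # 64 (128, 128)
print(projection_param_count(128, 2))  # 65536
```

Because $N_i \cdot d_{\text{head}} = C_i$, these projections total $4C_i^2$ weights regardless of how many heads the stage uses.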

$$\text{SR}(\mathbf{x}) = \text{Norm}\left(\text{Reshape}\left(\mathbf{x}, R_i\right) W^{S}\right)$$

Here, $\mathbf{x} \in \mathbb{R}^{(H_i W_i) \times C_i}$ represents an input sequence, where $H_i$ and $W_i$ are the height and width of the Stage $i$ feature map, and $R_i$ denotes the reduction ratio of the attention layers in Stage $i$. $\text{Reshape}(\mathbf{x}, R_i)$ reshapes the input sequence $\mathbf{x}$ into a sequence of size $\frac{H_i W_i}{R_i^2} \times (R_i^2 C_i)$. $W^{S} \in \mathbb{R}^{(R_i^2 C_i) \times C_i}$ is a linear projection that reduces the dimension of the input sequence to $C_i$. $\text{Norm}(\cdot)$ refers to layer normalization.
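The reduction step can be sketched in pure Python as below. The sketch assumes $\mathbf{x}$ is stored row-major over the $H_i \times W_i$ grid, that $\text{Reshape}(\mathbf{x}, R_i)$ gathers each non-overlapping $R_i \times R_i$ window into one token (the order in which the $R_i^2$ vectors are concatenated is our choice), and that layer normalization has no learnable scale or shift. Function names are hypothetical.

```python
import math


def matmul(a, b):
    """Multiply an (n x k) list-of-rows matrix by a (k x m) one."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def reshape_windows(x, h, w, r):
    """Group each non-overlapping r x r window of an (h*w) x C sequence into one
    token of length r*r*C. Result: (h*w / r^2) x (r^2 * C)."""
    assert h % r == 0 and w % r == 0, "assumes H and W divisible by R"
    out = []
    for wy in range(0, h, r):
        for wx in range(0, w, r):
            token = []
            for dy in range(r):
                for dx in range(r):
                    token.extend(x[(wy + dy) * w + (wx + dx)])  # row-major index
            out.append(token)
    return out


def layer_norm(row, eps=1e-5):
    """Layer normalization over one token (no learnable gain/bias in this sketch)."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return [(v - mean) / math.sqrt(var + eps) for v in row]


def spatial_reduction(x, h, w, r, w_s):
    """SR(x) = Norm(Reshape(x, R) W^S), with w_s of shape (r^2 * C) x C."""
    return [layer_norm(row) for row in matmul(reshape_windows(x, h, w, r), w_s)]


# Toy usage: a 4x4 map with C = 3 channels and R = 2 -> 16 tokens reduced to 4.
h, w, c, r = 4, 4, 3, 2
x = [[float(i * c + k) for k in range(c)] for i in range(h * w)]
w_s = [[((i + j) % 5) / 5 for j in range(c)] for i in range(r * r * c)]  # placeholder W^S
y = spatial_reduction(x, h, w, r, w_s)
print(len(y), len(y[0]))  # 4 3
```

Since the windows do not overlap, the reshape followed by $W^S$ computes the same thing as a convolution with kernel size and stride $R_i$, up to how the weights are laid out; either form shortens the key/value sequence from $H_i W_i$ to $\frac{H_i W_i}{R_i^2}$ tokens.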