
What is: Split Attention?

Source: ResNeSt: Split-Attention Networks
Year: 2020
Data Source: CC BY-SA - https://paperswithcode.com

A Split-Attention block enables attention across feature-map groups. As in ResNeXt blocks, the features can be divided into several groups, and the number of feature-map groups is given by a cardinality hyperparameter $K$. The resulting feature-map groups are called cardinal groups. Split-Attention blocks introduce a new radix hyperparameter $R$ that indicates the number of splits within a cardinal group, so the total number of feature groups is $G = KR$. We may apply a series of transformations $\{\mathcal{F}_1, \mathcal{F}_2, \cdots, \mathcal{F}_G\}$ to each individual group; the intermediate representation of each group is then $U_i = \mathcal{F}_i(X)$, for $i \in \{1, 2, \cdots, G\}$.
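The bookkeeping behind $G = KR$ is easy to get off by one, so here is a minimal sketch of which splits belong to which cardinal group. It uses the summation bounds given for $\hat{U}^k$ in the next paragraph: group $k$ owns splits $R(k-1)+1$ through $Rk$. The function name `cardinal_groups` is hypothetical and used only for illustration.

```python
def cardinal_groups(K: int, R: int) -> list[list[int]]:
    """Return the 1-based split indices that belong to each cardinal group.

    Cardinal group k (1-based) owns splits R*(k-1)+1 through R*k, so the
    K groups together cover all G = K*R splits exactly once.
    """
    if K < 1 or R < 1:
        raise ValueError("cardinality K and radix R must be positive")
    return [list(range(R * (k - 1) + 1, R * k + 1)) for k in range(1, K + 1)]


print(cardinal_groups(K=2, R=3))  # [[1, 2, 3], [4, 5, 6]]
```

With $K = 2$ and $R = 3$ there are $G = 6$ splits. The first cardinal group fuses $U_1, U_2, U_3$ and the second fuses $U_4, U_5, U_6$.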

A combined representation for each cardinal group can be obtained by fusing via an element-wise summation across multiple splits. The representation for the $k$-th cardinal group is $\hat{U}^k = \sum_{j=R(k-1)+1}^{Rk} U_j$, where $\hat{U}^k \in \mathbb{R}^{H\times W\times C/K}$ for $k \in \{1, 2, \ldots, K\}$, and $H$, $W$, and $C$ are the block output feature-map sizes. Global contextual information with embedded channel-wise statistics can be gathered with global average pooling across spatial dimensions, yielding $s^k \in \mathbb{R}^{C/K}$. Here the $c$-th component is calculated as:

$$s^k_c = \frac{1}{H\times W} \sum_{i=1}^{H} \sum_{j=1}^{W} \hat{U}^k_c(i, j).$$
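The two steps above can be sketched in plain Python. The first sums the $R$ splits of one cardinal group into $\hat{U}^k$, and the second reduces each channel to its spatial mean $s^k_c$. This is a minimal sketch under the following assumptions: each split is a nested list indexed `[channel][row][col]` with $C/K$ channels, and the helper names `fuse_cardinal_group` and `global_avg_pool` are hypothetical. A real network would use a tensor library for these operations.

```python
def fuse_cardinal_group(splits, k, R):
    """Element-wise sum of the R splits in cardinal group k (1-based).

    `splits` is a 0-based Python list of the G = K*R split outputs U_1..U_G,
    so U_j is stored at splits[j - 1]; group k uses U_{R(k-1)+1} .. U_{Rk}.
    """
    members = splits[R * (k - 1): R * k]
    C, H, W = len(members[0]), len(members[0][0]), len(members[0][0][0])
    return [[[sum(u[c][i][j] for u in members) for j in range(W)]
             for i in range(H)]
            for c in range(C)]


def global_avg_pool(u_hat):
    """s^k_c = (1 / (H*W)) * sum of channel c over all H x W positions."""
    H, W = len(u_hat[0]), len(u_hat[0][0])
    return [sum(sum(row) for row in channel) / (H * W) for channel in u_hat]


# K = 1, R = 2, one channel per split, 2 x 2 spatial maps.
U1 = [[[1, 2], [3, 4]]]
U2 = [[[5, 6], [7, 8]]]
u_hat = fuse_cardinal_group([U1, U2], k=1, R=2)
print(u_hat)                   # [[[6, 8], [10, 12]]]
print(global_avg_pool(u_hat))  # [9.0]
```

Because the fusion is a plain sum, $\hat{U}^k$ has the same shape as each individual split. The pooled vector $s^k$ therefore has one entry per channel of the cardinal group.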

A weighted fusion of the cardinal group representation $V^k \in \mathbb{R}^{H\times W\times C/K}$ is aggregated using channel-wise soft attention, where each feature-map channel is produced using a weighted combination over splits. The $c$-th channel is calculated as:

$$V^k_c = \sum_{i=1}^{R} a^k_i(c)\, U_{R(k-1)+i},$$

where $a_i^k(c)$ denotes a (soft) assignment weight given by:

$$a_i^k(c) = \begin{cases} \dfrac{\exp\left(\mathcal{G}^c_i(s^k)\right)}{\sum_{j=1}^{R} \exp\left(\mathcal{G}^c_j(s^k)\right)} & \textrm{if } R > 1, \\[2ex] \dfrac{1}{1 + \exp\left(-\mathcal{G}^c_i(s^k)\right)} & \textrm{if } R = 1, \end{cases}$$

and the mapping $\mathcal{G}_i^c$ determines the weight of each split for the $c$-th channel based on the global context representation $s^k$.
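The attention step above can be sketched as follows. The sketch normalizes per-split, per-channel scores with a softmax across the $R$ splits, or with a sigmoid gate when $R = 1$, and then forms each output channel $V^k_c$ as the weighted sum of that channel across the group's splits. This section does not specify how $\mathcal{G}^c_i$ is built, so the sketch takes its outputs as a given `logits` input; that input and the helper names `split_attention_weights` and `weighted_fusion` are hypothetical. Tensors are again nested lists indexed `[channel][row][col]`.

```python
import math


def split_attention_weights(logits):
    """Map scores logits[i][c] (standing in for G_i^c(s^k)) to weights a_i(c).

    Splits are 0-based here. For R > 1 the weights for each channel form a
    softmax over the R splits; for R = 1 each channel gets a sigmoid gate.
    """
    R, C = len(logits), len(logits[0])
    if R == 1:
        # Nothing to normalize across: gate each channel independently.
        return [[1.0 / (1.0 + math.exp(-logits[0][c])) for c in range(C)]]
    weights = [[0.0] * C for _ in range(R)]
    for c in range(C):
        # Softmax across splits, shifted by the max for numerical stability.
        m = max(logits[i][c] for i in range(R))
        exps = [math.exp(logits[i][c] - m) for i in range(R)]
        total = sum(exps)
        for i in range(R):
            weights[i][c] = exps[i] / total
    return weights


def weighted_fusion(group_splits, weights):
    """V^k_c = sum_i a_i(c) * (channel c of the i-th split of cardinal group k).

    group_splits holds the R splits of one cardinal group; weights is the
    R x C list returned by split_attention_weights.
    """
    R = len(group_splits)
    C = len(group_splits[0])
    H, W = len(group_splits[0][0]), len(group_splits[0][0][0])
    return [[[sum(weights[i][c] * group_splits[i][c][y][x] for i in range(R))
              for x in range(W)]
             for y in range(H)]
            for c in range(C)]


# R = 2 splits, one channel each, 1 x 2 spatial maps, equal scores.
a = split_attention_weights([[0.0], [0.0]])
print(a)  # [[0.5], [0.5]]
print(weighted_fusion([[[[2.0, 4.0]]], [[[6.0, 8.0]]]], a))  # [[[4.0, 6.0]]]
```

With equal scores the softmax splits the weight evenly, so $V^k$ is the average of the two splits. Unequal scores shift each channel toward the split with the higher score. In the $R = 1$ case the sigmoid weights do not sum to one across splits; they act as independent per-channel gates on the single split.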