What is: Semantic Cross Attention?

Source: SCAM! Transferring humans between images with Semantic Cross Attention Modulation
Year: 2000
Data Source: CC BY-SA - https://paperswithcode.com

Semantic Cross Attention (SCA) is based on cross attention, which we restrict with respect to a semantic mask.

SCA serves one of two purposes, depending on which input provides the queries and which provides the keys. Either it lets the feature map receive information from a semantically restricted set of latents, or it lets a set of latents retrieve information from a semantically restricted region of the feature map.

SCA is defined as:

\begin{equation} \text{SCA}(I_{1}, I_{2}, I_{3}) = \sigma\left(\frac{QK^T\odot I_{3} +\tau \left(1-I_{3}\right)}{\sqrt{d_{in}}}\right)V \quad , \end{equation}

where $I_{1}, I_{2}, I_{3}$ are the inputs, with $I_{1}$ attending $I_{2}$, and $I_{3}$ the mask that forces tokens from $I_{1}$ to attend only specific tokens from $I_{2}$. The attention values requiring masking are filled with $-\infty$ before the softmax (in practice, $\tau = -10^9$). $Q = W_Q I_{1}$, $K = W_K I_{2}$ and $V = W_V I_{2}$ are the queries, keys and values, $d_{in}$ is the internal attention dimension, and $\sigma(\cdot)$ is the softmax operation.
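As a minimal sketch of the formula above, the following plain-Python function computes masked attention for small matrices stored as lists of rows. It assumes the row-vector convention (queries computed as `i1 @ w_q`, with projection matrices of shape feature dimension by $d_{in}$), which is a choice made for this example. The names `sca`, `project`, `dot` and `softmax` are illustrative and are not taken from the paper's code; a practical implementation would use a tensor library instead of lists.

```python
import math

TAU = -1e9  # large negative stand-in for -infinity, as given in the text


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def project(x, w):
    """Matrix product x @ w for matrices stored as lists of rows."""
    cols = list(zip(*w))
    return [[dot(row, col) for col in cols] for row in x]


def softmax(row):
    peak = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def sca(i1, i2, mask, w_q, w_k, w_v):
    """Tokens of i1 attend tokens of i2 wherever mask[i][j] == 1.

    Assumption: w_q, w_k, w_v have shape (feature_dim, d_in).
    Returns (output, attention_weights).
    """
    q, k, v = project(i1, w_q), project(i2, w_k), project(i2, w_v)
    scale = math.sqrt(len(w_q[0]))  # sqrt(d_in)
    weights = []
    for q_i, mask_i in zip(q, mask):
        # (Q K^T * I3 + tau * (1 - I3)) / sqrt(d_in), one row at a time
        scores = [(dot(q_i, k_j) * m + TAU * (1 - m)) / scale
                  for k_j, m in zip(k, mask_i)]
        weights.append(softmax(scores))
    return project(weights, v), weights


# Toy example: 2 query tokens, 3 key tokens, identity projections.
eye = [[1.0, 0.0], [0.0, 1.0]]
i1 = [[1.0, 0.0], [0.0, 1.0]]
i2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
mask = [[1, 0, 0], [0, 1, 1]]
out, attn = sca(i1, i2, mask, eye, eye, eye)
```

In this toy example the first query may only see the first key, so its attention row collapses to a one-hot vector. The second query is blocked from the first key and splits its weight evenly between the two keys it may attend.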

Let $X \in \mathbb{R}^{n\times C}$ be the feature map, with $n$ the number of pixels and $C$ the number of channels. Let $Z \in \mathbb{R}^{m\times d}$ be a set of $m$ latents of dimension $d$, and $s$ the number of semantic labels. Each semantic label is assigned $k$ latents, such that $m = k \times s$. Each semantic label mask is repeated $k$ times, once per latent of that label, to form $S \in \{0;1\}^{n \times m}$.
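The construction of the pixel-to-latent mask can be sketched as follows. For illustration only, the sketch assumes that the latents are ordered label by label, so that latent `j` belongs to label `j // k`; the source does not specify an ordering. The function name `build_pixel_latent_mask` and the flattened `pixel_labels` input are hypothetical.

```python
def build_pixel_latent_mask(pixel_labels, num_labels, k):
    """Return S in {0,1}^(n x m) with m = k * num_labels.

    S[p][j] is 1 when latent j carries the label of pixel p.
    Assumption: latent j belongs to label j // k (an ordering chosen here).
    """
    m = k * num_labels
    return [[1 if j // k == label else 0 for j in range(m)]
            for label in pixel_labels]


# 4 pixels, s = 2 labels, k = 2 latents per label, so S is 4 x 4.
S = build_pixel_latent_mask([0, 0, 1, 0], num_labels=2, k=2)
```

Each row of the resulting mask contains exactly `k` ones, since every pixel has one label and every label owns `k` latents.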

We can differentiate 3 types of SCA:

(a) SCA with pixels $X$ attending latents $Z$: $\text{SCA}(X, Z, S)$, where $W_{Q} \in \mathbb{R}^{n\times d_{in}}$ and $W_{K}, W_{V} \in \mathbb{R}^{m\times d_{in}}$. The idea is to force the pixels from a semantic region to attend latents that are associated with the same label.

(b) SCA with latents $Z$ attending pixels $X$: $\text{SCA}(Z, X, S)$, where $W_{Q} \in \mathbb{R}^{m\times d_{in}}$ and $W_{K}, W_{V} \in \mathbb{R}^{n\times d_{in}}$. The idea is to semantically mask attention values so that latents are forced to attend semantically corresponding pixels.

(c) SCA with latents $Z$ attending themselves: $\text{SCA}(Z, Z, M)$, where $W_{Q}, W_{K}, W_{V} \in \mathbb{R}^{n\times d_{in}}$. We denote this mask $M \in \mathbb{N}^{m\times m}$, with $M(i,j) = 1$ if the semantic label of latent $i$ is the same as that of latent $j$, and $0$ otherwise. The idea is to let the latents attend only latents that share the same semantic label.
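To make the three masks concrete, here is a small sketch that derives all of them from per-pixel and per-latent label lists. Two points are assumptions of this example rather than statements from the source: the mask for variant (b) is taken to be the transpose of the pixel-to-latent mask, so that its shape matches the $m \times n$ score matrix, and all helper names are hypothetical.

```python
def pixel_to_latent_mask(pixel_labels, latent_labels):
    """Variant (a): n x m mask, pixels attend latents with the same label."""
    return [[int(p == z) for z in latent_labels] for p in pixel_labels]


def latent_to_pixel_mask(pixel_labels, latent_labels):
    """Variant (b): m x n mask, assumed here to be the transpose of (a)."""
    return [[int(z == p) for p in pixel_labels] for z in latent_labels]


def latent_to_latent_mask(latent_labels):
    """Variant (c): m x m mask, latents attend latents with the same label."""
    return [[int(a == b) for b in latent_labels] for a in latent_labels]


pixel_labels = [0, 1, 1]       # n = 3 pixels
latent_labels = [0, 0, 1, 1]   # s = 2 labels, k = 2 latents each, m = 4

mask_a = pixel_to_latent_mask(pixel_labels, latent_labels)
mask_b = latent_to_pixel_mask(pixel_labels, latent_labels)
mask_c = latent_to_latent_mask(latent_labels)
```

With latents grouped by label as in this example, the latent-to-latent mask is block diagonal, with one $k \times k$ block of ones per semantic label.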