
What is: Adaptive Masking?

Source: Adaptive Attention Span in Transformers
Year: 2019
Data Source: CC BY-SA - https://paperswithcode.com

Adaptive Masking is a type of attention mechanism that lets a model learn the size of the context it attends over. For each head in Multi-Head Attention, a masking function is added to control the span of the attention. A masking function is a non-increasing function that maps a distance to a value in $[0, 1]$. Adaptive masking uses the following soft masking function $m_z$, parametrized by a real value $z \in [0, S]$, where $S$ is the maximum span:

$$m_z(x) = \min\left[\max\left[\frac{1}{R}\left(R + z - x\right), 0\right], 1\right]$$

where $R$ is a hyperparameter that controls its softness. As a function of the distance $x$, this piecewise function equals 1 for $x \le z$, decreases linearly to 0 for $z < x < z + R$, and is 0 for $x \ge z + R$. This soft masking function is inspired by Jernite et al. (2017). The attention weights are then computed over the masked span:

$$a_{tr} = \frac{m_z(t - r)\exp\left(s_{tr}\right)}{\sum_{q=t-S}^{t-1} m_z(t - q)\exp\left(s_{tq}\right)}$$

where $s_{tr}$ is the attention score between the current position $t$ and a past position $r$, and the sum runs over the $S$ positions preceding $t$.
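To make the two formulas concrete, here is a minimal pure-Python sketch of the soft mask $m_z$ and the masked attention weights $a_{tr}$ for a single head and a single query position $t$. The function names, the `scores` layout (with `scores[r]` holding $s_{tr}$), and the choice to skip positions before the start of the sequence are assumptions made for illustration; this is not the paper's implementation.

```python
import math


def soft_mask(x: float, z: float, R: float) -> float:
    """Soft mask m_z(x) = min(max((R + z - x) / R, 0), 1).

    Equals 1 for distances x <= z, ramps linearly down to 0 between
    x = z and x = z + R, and is 0 for x >= z + R.
    """
    return min(max((R + z - x) / R, 0.0), 1.0)


def masked_attention_weights(scores: list[float], t: int, z: float,
                             R: float, S: int) -> dict[int, float]:
    """Return {r: a_tr} for the past positions r = t-S, ..., t-1.

    Hypothetical helper: scores[r] is assumed to hold the raw score s_tr.
    Positions before 0 are skipped (an assumption for the start of a sequence).
    """
    positions = range(max(0, t - S), t)
    # Subtracting the max score cancels in the ratio but avoids exp overflow.
    s_max = max(scores[r] for r in positions)
    weighted = {r: soft_mask(t - r, z, R) * math.exp(scores[r] - s_max)
                for r in positions}
    total = sum(weighted.values())
    if total == 0.0:
        raise ValueError("mask zeroes out every position; increase z or R")
    return {r: w / total for r, w in weighted.items()}


# Usage: equal scores, so only the mask shapes the weights.
a = masked_attention_weights([0.0] * 10, t=10, z=3, R=2, S=8)
# a[9], a[8], a[7] ≈ 0.2857 (1/3.5); a[6] ≈ 0.1429 (0.5/3.5); a[5] ... a[2] == 0.0
```

With $z = 3$ and $R = 2$, the three most recent positions keep their full weight, the fourth is halved by the linear ramp, and every position at distance 5 or more is masked out, so the head effectively attends over a span of about $z + R$ tokens rather than the full $S$.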

An $\ell_1$ penalty on the parameters $z_i$ of each attention head $i$ of the model is added to the loss function:

$$L = -\log P\left(w_1, \dots, w_T\right) + \frac{\lambda}{M}\sum_{i} z_i$$

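where $\lambda > 0$ is the regularization hyperparameter and $M$ is the number of heads in each layer. Because every $z_i$ lies in $[0, S]$ and is therefore non-negative, the $\ell_1$ penalty reduces to the plain sum $\sum_i z_i$. This formulation is differentiable in the parameters $z_i$, which are learned jointly with the rest of the model.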
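The sketch below illustrates why $z_i$ can be trained by gradient descent: the mask is piecewise linear in $z$, so its derivative is $1/R$ on the ramp and 0 where it is saturated, while the penalty contributes a constant $\lambda / M$ to the gradient of every $z_i$, steadily pushing spans shorter unless the task loss pulls them back. The helper names are hypothetical, the negative log-likelihood is assumed to be computed elsewhere, and the kinks at $x = z$ and $x = z + R$ are simply assigned a zero derivative here; in practice an autodiff framework would compute these gradients.

```python
def soft_mask(x: float, z: float, R: float) -> float:
    """m_z(x) = min(max((R + z - x) / R, 0), 1)."""
    return min(max((R + z - x) / R, 0.0), 1.0)


def soft_mask_grad_z(x: float, z: float, R: float) -> float:
    """Derivative of m_z(x) with respect to z (hypothetical helper).

    It is 1/R on the linear ramp z < x < z + R and 0 where the mask is
    saturated at 1 or 0; the two kink points are assigned 0 by assumption.
    """
    return 1.0 / R if z < x < z + R else 0.0


def penalized_loss(nll: float, z_values: list[float], lam: float, M: int) -> float:
    """L = nll + (lam / M) * sum_i z_i, where nll = -log P(w_1, ..., w_T)."""
    return nll + lam / M * sum(z_values)


def penalty_grad(lam: float, M: int) -> float:
    """Gradient of the penalty term w.r.t. any single z_i: the constant lam / M."""
    return lam / M


# Finite-difference check of the mask gradient on the ramp (x=7, z=5, R=4).
h = 1e-6
fd = (soft_mask(7, 5 + h, 4) - soft_mask(7, 5 - h, 4)) / (2 * h)
print(fd, soft_mask_grad_z(7, 5, 4))                 # both close to 0.25
print(penalized_loss(2.0, [1.0, 3.0], lam=0.5, M=2))  # 3.0
```

Because the gradient with respect to $z$ is non-zero only for distances that fall on the ramp, a larger $R$ lets more positions send a learning signal to $z$, at the cost of a blurrier span boundary.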