
What is: ReLIC?

Source: Representation Learning via Invariant Causal Mechanisms
Year: 2020
Data Source: CC BY-SA - https://paperswithcode.com

ReLIC, or Representation Learning via Invariant Causal Mechanisms, is a self-supervised learning objective that enforces invariant prediction of proxy targets across augmentations through an invariance regularizer, which yields improved generalization guarantees.

We can write the objective as:

$$\mathbb{E}_{X}\, \mathbb{E}_{a_{lk}, a_{qt} \sim \mathcal{A} \times \mathcal{A}} \sum_{b \in \{a_{lk}, a_{qt}\}} \mathcal{L}_{b}\left(Y^{R}, f(X)\right) \quad \text{s.t.} \quad KL\left(p^{do(a_{lk})}\left(Y^{R} \mid f(X)\right), p^{do(a_{qt})}\left(Y^{R} \mid f(X)\right)\right) \leq \rho$$

where $\mathcal{L}$ is the proxy task loss and $KL$ is the Kullback-Leibler (KL) divergence. Note that any distance measure on distributions can be used in place of the KL divergence.
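The constraint in this objective compares two categorical distributions over the proxy labels $Y^R$. A minimal NumPy sketch of the KL divergence between such distributions, and of the $\leq \rho$ check, is given below. The function names and the `eps` smoothing term are illustrative assumptions, not part of ReLIC; because any distance measure on distributions may be used, `categorical_kl` is the piece one would swap out.

```python
import numpy as np


def categorical_kl(p, q, eps=1e-12):
    """KL(p || q) for categorical distributions along the last axis.

    p, q: arrays of probabilities with the same shape; each row sums to 1.
    eps guards against log(0); it is an illustrative choice, not from ReLIC.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1)


def satisfies_invariance(p_lk, p_qt, rho):
    """Check KL(p^{do(a_lk)}, p^{do(a_qt)}) <= rho for every row (datapoint)."""
    return bool(np.all(categorical_kl(p_lk, p_qt) <= rho))


# Usage: two slightly different predictive distributions over 3 proxy labels.
p = np.array([[0.7, 0.2, 0.1]])
q = np.array([[0.6, 0.3, 0.1]])
print(categorical_kl(p, q), satisfies_invariance(p, q, rho=0.05))
```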

Concretely, as the proxy task we associate to every datapoint $x_{i}$ the label $y_{i}^{R} = i$. This corresponds to the instance discrimination task, commonly used in contrastive learning. We take pairs of points $(x_{i}, x_{j})$ to compute similarity scores and use pairs of augmentations $a_{lk} = (a_{l}, a_{k}) \in \mathcal{A} \times \mathcal{A}$ to perform a style intervention. Given a batch of samples $\{x_{i}\}_{i=1}^{N} \sim \mathcal{D}$, we use

$$p^{do(a_{lk})}\left(Y^{R} = j \mid f(x_{i})\right) \propto \exp\left(\phi\left(f(x_{i}^{a_{l}}), h(x_{j}^{a_{k}})\right) / \tau\right)$$

with $x^{a}$ the data augmented with $a$ and $\tau$ a softmax temperature parameter. We encode $f$ using a neural network and choose $h$ to be related to $f$, e.g. $h = f$, or a network whose weights are an exponential moving average of the weights of $f$ (as in target networks). To compare representations we use the function $\phi(f(x_{i}), h(x_{j})) = \langle g(f(x_{i})), g(h(x_{j})) \rangle$, where $g$ is a fully-connected neural network often called the critic.
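To make these pieces concrete, here is a NumPy sketch that turns a batch of representations into the distribution $p^{do(a_{lk})}(Y^{R} = j \mid f(x_i))$: the critic $g$ is applied to both sides, inner products give $\phi$, and a softmax with temperature $\tau$ normalizes over $j$. The two-layer ReLU critic, the weight shapes, the function names, and the EMA decay of 0.99 are hypothetical choices for illustration only. `ema_update` shows one way to maintain $h$ as a moving average of the weights of $f$.

```python
import numpy as np


def critic(z, W1, W2):
    """Hypothetical two-layer fully-connected critic g with a ReLU hidden layer."""
    return np.maximum(z @ W1, 0.0) @ W2


def do_distribution(f_al, h_ak, W1, W2, tau=0.1):
    """Row i is p^{do(a_lk)}(Y^R = j | f(x_i)) over j = 1..N.

    f_al: (N, d) representations f(x_i^{a_l})
    h_ak: (N, d) representations h(x_j^{a_k})
    phi(i, j) = <g(f(x_i^{a_l})), g(h(x_j^{a_k}))>, scaled by 1/tau,
    then normalized with a softmax over j.
    """
    scores = critic(f_al, W1, W2) @ critic(h_ak, W1, W2).T / tau  # (N, N)
    scores -= scores.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    probs = np.exp(scores)
    return probs / probs.sum(axis=1, keepdims=True)


def ema_update(target_params, online_params, decay=0.99):
    """One EMA step: h's weights move toward f's weights (decay is an assumption)."""
    return [decay * t + (1.0 - decay) * o for t, o in zip(target_params, online_params)]


# Usage with random weights and h = f applied to a nearby view.
rng = np.random.default_rng(0)
N, d, k = 4, 8, 16
W1 = rng.normal(size=(d, k)) / np.sqrt(d)
W2 = rng.normal(size=(k, k)) / np.sqrt(k)
f_al = rng.normal(size=(N, d))
h_ak = f_al + 0.01 * rng.normal(size=(N, d))
p = do_distribution(f_al, h_ak, W1, W2, tau=0.5)
print(p.shape, p.sum(axis=1))
```

Applying the same critic to both arguments keeps $\phi$ an inner product in a shared space; with $h = f$ the EMA step is simply skipped.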

Combining these pieces, we learn representations by minimizing the following objective over the full set of data $x_{i} \in \mathcal{D}$ and augmentations $a_{lk} \in \mathcal{A} \times \mathcal{A}$:

$$-\sum_{i=1}^{N} \sum_{a_{lk}} \log \frac{\exp\left(\phi\left(f(x_{i}^{a_{l}}), h(x_{i}^{a_{k}})\right) / \tau\right)}{\sum_{m=1}^{M} \exp\left(\phi\left(f(x_{i}^{a_{l}}), h(x_{m}^{a_{k}})\right) / \tau\right)} + \alpha \sum_{a_{lk}, a_{qt}} KL\left(p^{do(a_{lk})}, p^{do(a_{qt})}\right)$$

with $M$ the number of points used to construct the contrast set and $\alpha$ the weight of the invariance penalty. The shorthand $p^{do(a)}$ stands for $p^{do(a)}\left(Y^{R} = j \mid f(x_{i})\right)$. The Figure shows a schematic of the ReLIC objective.
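Putting the objective together, the NumPy sketch below evaluates its value for one batch. It assumes the contrast set is the batch itself ($M = N$), takes the critic $g$ as the identity by default, and sums the KL penalty over datapoints $i$ and over all ordered pairs of distinct augmentation pairs; these are simplifying assumptions, and `relic_loss` is a hypothetical name rather than a reference implementation. It computes the loss value only; training would need an autodiff framework.

```python
import itertools

import numpy as np


def log_softmax(x, axis=-1):
    """Numerically stable log-softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def relic_loss(f_views, h_views, aug_pairs, g=lambda z: z, tau=0.1, alpha=1.0):
    """Sketch of the ReLIC objective for one batch.

    f_views[a], h_views[a]: (N, d) arrays holding f(x^a) and h(x^a) for augmentation a.
    aug_pairs: list of (l, k) index tuples, the augmentation pairs a_lk.
    g: critic; identity by default (an assumption made for brevity).
    The contrast set is the batch itself, so M = N in this sketch.
    """
    log_p = {}
    for l, k in aug_pairs:
        # scores[i, m] = phi(f(x_i^{a_l}), h(x_m^{a_k})) / tau
        scores = g(f_views[l]) @ g(h_views[k]).T / tau
        # row i: log p^{do(a_lk)}(Y^R = j | f(x_i))
        log_p[(l, k)] = log_softmax(scores, axis=1)

    n = next(iter(log_p.values())).shape[0]
    idx = np.arange(n)
    # Instance discrimination: the proxy label of x_i is y_i^R = i (the diagonal).
    contrastive = -sum(lp[idx, idx].sum() for lp in log_p.values())

    # Invariance penalty: KL over ordered pairs of distinct augmentation pairs,
    # summed over datapoints i (the sum over i is an assumption of this sketch).
    penalty = 0.0
    for lk, qt in itertools.permutations(log_p.keys(), 2):
        p = np.exp(log_p[lk])
        penalty += np.sum(p * (log_p[lk] - log_p[qt]))
    return contrastive + alpha * penalty


# Usage: two hypothetical noisy views of a toy batch, with h = f.
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))
views = [x + 0.05 * rng.normal(size=x.shape) for _ in range(2)]
print(relic_loss(views, views, aug_pairs=[(0, 1), (1, 0)], tau=0.5, alpha=1.0))
```

With two augmentations, all of $\mathcal{A} \times \mathcal{A}$ can be enumerated as `list(itertools.product(range(2), repeat=2))` and passed as `aug_pairs`. Setting `alpha=0` recovers a plain instance-discrimination contrastive loss.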