
What is: SimCLR?

Source: A Simple Framework for Contrastive Learning of Visual Representations
Year: 2020
Data Source: CC BY-SA - https://paperswithcode.com

SimCLR is a framework for contrastive learning of visual representations. It learns representations by maximizing agreement between differently augmented views of the same data example via a contrastive loss in the latent space. It consists of:

  • A stochastic data augmentation module that transforms any given data example randomly, resulting in two correlated views of the same example, denoted $\mathbf{\tilde{x}}_i$ and $\mathbf{\tilde{x}}_j$, which are considered a positive pair. SimCLR sequentially applies three simple augmentations: random cropping followed by a resize back to the original size, random color distortion, and random Gaussian blur. The authors find that the combination of random crop and color distortion is crucial for good performance (a minimal sketch follows this list).

  • A neural network base encoder $f(\cdot)$ that extracts representation vectors from augmented data examples. The framework allows various choices of network architecture without any constraints. The authors opt for simplicity and adopt ResNet, obtaining $h_i = f(\mathbf{\tilde{x}}_i) = \text{ResNet}(\mathbf{\tilde{x}}_i)$, where $h_i \in \mathbb{R}^d$ is the output after the average pooling layer.

  • A small neural network projection head $g(\cdot)$ that maps representations to the space where the contrastive loss is applied. The authors use an MLP with one hidden layer to obtain $z_i = g(h_i) = W^{(2)}\sigma(W^{(1)}h_i)$, where $\sigma$ is a ReLU nonlinearity. They find it beneficial to define the contrastive loss on the $z_i$'s rather than the $h_i$'s (the encoder and head are sketched after this list).

  • A contrastive loss function defined for a contrastive prediction task. Given a set $\{\mathbf{\tilde{x}}_k\}$ including a positive pair of examples $\mathbf{\tilde{x}}_i$ and $\mathbf{\tilde{x}}_j$, the contrastive prediction task aims to identify $\mathbf{\tilde{x}}_j$ in $\{\mathbf{\tilde{x}}_k\}_{k \neq i}$ for a given $\mathbf{\tilde{x}}_i$.
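
To make the augmentation module concrete, here is a minimal sketch using torchvision. The crop scale, color-jitter strengths, blur kernel size, and application probabilities below are assumptions loosely patterned on the paper's reported settings, not the authors' exact configuration.

```python
import torchvision.transforms as T

def make_augmentation(image_size=224, jitter_strength=0.5):
    """Sketch of SimCLR's stochastic augmentation pipeline (assumed parameters)."""
    color_jitter = T.ColorJitter(
        brightness=0.8 * jitter_strength,
        contrast=0.8 * jitter_strength,
        saturation=0.8 * jitter_strength,
        hue=0.2 * jitter_strength,
    )
    return T.Compose([
        # Random crop, then resize back to the original size.
        T.RandomResizedCrop(image_size),
        T.RandomHorizontalFlip(),
        # Random color distortion: jitter applied most of the time,
        # plus occasional conversion to grayscale.
        T.RandomApply([color_jitter], p=0.8),
        T.RandomGrayscale(p=0.2),
        # Random Gaussian blur, applied with probability 0.5.
        T.RandomApply([T.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.5),
        T.ToTensor(),
    ])

augment = make_augmentation()
# Applying `augment` twice to the same image yields the positive pair:
# x_i, x_j = augment(img), augment(img)
```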
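The base encoder $f(\cdot)$ and projection head $g(\cdot)$ can likewise be sketched in a few lines of PyTorch. The ResNet-50 backbone and one-hidden-layer MLP follow the paper's description; the projection dimension of 128 and the use of a recent torchvision API are assumptions of this sketch.

```python
import torch.nn as nn
import torchvision.models as models

class SimCLRModel(nn.Module):
    """Base encoder f(.) plus projection head g(.) (minimal sketch)."""

    def __init__(self, proj_dim=128):
        super().__init__()
        backbone = models.resnet50(weights=None)
        feat_dim = backbone.fc.in_features   # 2048-dim output after average pooling
        backbone.fc = nn.Identity()          # drop the classifier: f(x) is the pooled feature
        self.encoder = backbone
        # One-hidden-layer MLP: z = W2 * ReLU(W1 * h)
        self.projection = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, proj_dim),
        )

    def forward(self, x):
        h = self.encoder(x)      # representation h_i, used for downstream tasks
        z = self.projection(h)   # projection z_i, used by the contrastive loss
        return h, z
```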

A minibatch of $N$ examples is randomly sampled, and the contrastive prediction task is defined on pairs of augmented examples derived from the minibatch, resulting in $2N$ data points. Negative examples are not sampled explicitly. Instead, given a positive pair, the other $2(N-1)$ augmented examples within the minibatch are treated as negatives. An NT-Xent (normalized temperature-scaled cross-entropy) loss function is used (see components).
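Concretely, for a positive pair $(i, j)$, with $\text{sim}(u, v) = u^{\top}v / (\lVert u \rVert \lVert v \rVert)$ denoting cosine similarity and $\tau$ a temperature parameter, the paper defines the loss as

$$\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k)/\tau)},$$

with the final loss averaged over all positive pairs, both $(i, j)$ and $(j, i)$, in the minibatch. Below is a minimal PyTorch sketch of this loss; the default temperature value is an assumption for illustration, not a recommended setting.

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z_i, z_j, temperature=0.5):
    """NT-Xent loss sketch. z_i, z_j: [N, d] projections of the two views;
    row k of z_i and row k of z_j form a positive pair."""
    N = z_i.size(0)
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # [2N, d], unit vectors
    sim = z @ z.t() / temperature                         # pairwise cosine / tau
    # Exclude self-similarity, so each anchor competes against the other 2N-1 points.
    sim.fill_diagonal_(float('-inf'))
    # The positive for anchor k is k+N (first half) or k-N (second half).
    targets = torch.cat([torch.arange(N, 2 * N), torch.arange(N)]).to(z.device)
    # Cross entropy over rows implements the -log softmax form, averaged over 2N anchors.
    return F.cross_entropy(sim, targets)
```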