
What is: Contrastive BERT?

Source: CoBERL: Contrastive BERT for Reinforcement Learning
Year: 2021
Data Source: CC BY-SA - https://paperswithcode.com

Contrastive BERT (CoBERL) is a reinforcement learning agent that combines a new contrastive loss with a hybrid LSTM-transformer architecture to improve data efficiency in RL. It uses bidirectional masked prediction in combination with a generalization of recent contrastive methods to learn better representations for transformers in RL, without the need for hand-engineered data augmentations.
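
To make the masked-prediction-plus-contrastive idea concrete, here is a minimal PyTorch sketch of such an objective. It is a simplified stand-in, not the paper's exact loss (CoBERL adapts RELIC, which adds further regularization terms): an InfoNCE-style loss that aligns the transformer's predictions at masked time steps with the corresponding encoder embeddings. The function name, the stop-gradient on the targets, and all parameters are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def masked_contrastive_loss(pred, target, mask, temperature=0.1):
    """InfoNCE-style stand-in for CoBERL's contrastive objective.

    pred:   (T, D) transformer outputs X_t at every time step
    target: (T, D) encoder embeddings Y_t, used as targets
    mask:   (T,) boolean, True at the time steps whose inputs were masked
    """
    pred = F.normalize(pred[mask], dim=-1)                # predictions at masked steps
    target = F.normalize(target[mask].detach(), dim=-1)   # stop-gradient on targets (a design choice)
    logits = pred @ target.t() / temperature              # pairwise similarities
    labels = torch.arange(logits.size(0))                 # the matching time step is the positive
    return F.cross_entropy(logits, labels)

# Example: BERT-style masking of roughly 15% of a 32-step trajectory.
T, D = 32, 64
x_t, y_t = torch.randn(T, D), torch.randn(T, D)
mask = torch.rand(T) < 0.15
loss = masked_contrastive_loss(x_t, y_t, mask)
```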

For the architecture, a residual network encodes observations into embeddings $Y_t$. $Y_t$ is fed through a causally masked GTrXL transformer, which computes the predicted masked inputs $X_t$ and passes them, together with $Y_t$, to a learnt gate. The output of the gate is passed through a single LSTM layer to produce the values used for computing the RL loss. A contrastive loss is computed using the predicted masked inputs $X_t$ and $Y_t$ as targets; for this, the causal mask of the transformer is not used.
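
The data flow above can be summarized in a short PyTorch sketch. This is a loose approximation under stated assumptions: the residual network is replaced by a small MLP, GTrXL by a standard `nn.TransformerEncoder` with a causal mask, and the learnt gate (GRU-type gating in the paper) by a simple sigmoid gate; `CoBERLSketch` and all dimensions are hypothetical.

```python
import torch
import torch.nn as nn

class CoBERLSketch(nn.Module):
    """Loose sketch of CoBERL's data flow: encoder -> transformer -> gate -> LSTM."""

    def __init__(self, obs_dim, d_model=64):
        super().__init__()
        # Stand-in for the residual network that produces the embeddings Y_t.
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, d_model), nn.ReLU(), nn.Linear(d_model, d_model))
        # Stand-in for GTrXL: a standard transformer encoder.
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=2)
        # Simplified learnt gate over [X_t, Y_t]; the paper uses GRU-type gating.
        self.gate = nn.Linear(2 * d_model, d_model)
        self.lstm = nn.LSTM(d_model, d_model, batch_first=True)

    def forward(self, obs):                      # obs: (B, T, obs_dim)
        y = self.encoder(obs)                    # embeddings Y_t
        causal = nn.Transformer.generate_square_subsequent_mask(y.size(1))
        x = self.transformer(y, mask=causal)     # predicted inputs X_t (causal, for RL)
        g = torch.sigmoid(self.gate(torch.cat([x, y], dim=-1)))
        out, _ = self.lstm(g * x + (1 - g) * y)  # values used for the RL loss
        # For the contrastive loss, the transformer is re-run on masked inputs
        # without the causal mask; X_t and Y_t are then compared as in the loss sketch.
        return out, x, y
```

Here `out` would feed the RL value and policy heads, while `x` and `y` feed a masked contrastive loss like the earlier sketch.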