An In-Depth Guide to Contrastive Learning: Techniques, Models, and Applications

Traditionally, machine learning (ML) can be broadly classified into two types: supervised and unsupervised learning. In supervised learning, the data comes with labels, and these labels are used to train the model. For example, labelled images of different objects can be used to train an image classifier.

Unsupervised learning, on the other hand, doesn't require any labelling. Instead, we look for patterns in the data without any prior information. This is harder, but also appealing. Huge amounts of data are available, and new data is produced every day, so collecting data is easy. Labelling it, however, takes a lot of time and money.

To harness the heaps of data around us, we can use an approach known as self-supervised learning. In self-supervised learning, we take unlabelled data and mimic supervised learning by deriving the supervisory signal from the data itself.

# Self-Supervised Learning (SSL)

A common way to do this is to split the data into positive and negative samples, similar to binary (supervised) classification. The object under consideration is treated as a positive example and all the other samples as negatives.

Self-supervised methods that learn joint embeddings map related inputs, such as two views of the same image, to nearby points. They can be broadly classified into two types:

  • Contrastive methods
  • Non-contrastive methods

In contrastive methods, we take different samples of the same data, such as different views of the same image. We then try to maximize their similarity scores while minimizing them for the other samples/images. Non-contrastive methods, on the other hand, do not use negative samples. Well-known SSL methods such as BYOL and DINO are examples of non-contrastive methods. SimCLR and MoCo also use negative examples and are examples of contrastive methods.

# Contrastive Learning (CL)

Contrastive learning centres on a simple idea: learn a representation that maximizes the similarity between positive pairs while minimizing it for negative pairs. For example, if I input an image of a mango, the goal is to maximize its similarity with other mango images. At the same time, its similarity with other images should be minimized.

The simplest setting treats each data point as a positive only with respect to itself, and every other data point as a negative. Suppose the point is $x$ and $f$ is the encoder. Then the pair $(f(x), f(x))$ should be considered positive and $(f(x), f(y))$ negative, where $y$ denotes any other data point. Here $x$ can represent the mango's image and $f(x)$ its encoder representation, and similarly for the other images.

# Training

For each pair of data points $(x, y)$, we calculate a similarity score between $f(x)$ and $f(y)$. We use this score to train the model, minimizing it for negative pairs ($x \neq y$) and maximizing it when $x = y$. As in other optimization problems, this is done by choosing a suitable loss function, such as the cross-entropy loss.

To make training practical, we work in batches. Say we have a batch of 32 images. For each image, the similarity with itself (the positive) should be maximized. Its similarity with each of the other 31 images (the negatives) should be as small as possible.
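The batch objective described above can be sketched in plain Python. No deep-learning framework is used, to keep it self-contained. Each anchor is scored against every item in the batch. The loss is the cross-entropy in which the anchor's own counterpart is the correct "class". The helper names and the temperature value (which sharpens the softmax) are illustrative assumptions, not taken from a specific paper or library.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def batch_contrastive_loss(anchors, candidates, temperature=0.1):
    """Mean cross-entropy where anchors[i] should match candidates[i];
    every candidates[j] with j != i acts as a negative."""
    total = 0.0
    for i, anchor in enumerate(anchors):
        # similarity of anchor i to every candidate in the batch, scaled by the temperature
        logits = [cosine(anchor, c) / temperature for c in candidates]
        # log-sum-exp with the maximum subtracted for numerical stability
        top = max(logits)
        log_denominator = top + math.log(sum(math.exp(s - top) for s in logits))
        # cross-entropy with index i as the correct "class"
        total += log_denominator - logits[i]
    return total / len(anchors)


batch = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
shifted = batch[1:] + batch[:1]  # every anchor now faces the wrong partner
print(batch_contrastive_loss(batch, batch))    # close to 0: each item matches itself
print(batch_contrastive_loss(batch, shifted))  # large: positives are mismatched
```

With a real encoder, `anchors` and `candidates` would hold the embeddings $f(x)$ of the batch, for example 32 images.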

However, with a single positive against many negatives, it is quite hard to learn discriminative features. We therefore use techniques such as data augmentation to create multiple, different copies of the same data point, which serve as additional positives.
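Here is a toy sketch of how augmentation turns one data point into a positive pair. To stay self-contained, it treats an "image" as a 1-D list of pixel values. `random_crop` and `jitter` are hypothetical stand-ins for image cropping and colour distortion. They are not the augmentation pipeline of any particular method.

```python
import random


def random_crop(signal, crop_len, rng):
    """Take a random contiguous window of length crop_len (1-D stand-in for cropping)."""
    start = rng.randint(0, len(signal) - crop_len)
    return signal[start:start + crop_len]


def jitter(signal, rng, scale=0.1):
    """Randomly rescale intensities and add small noise (stand-in for colour distortion)."""
    gain = 1.0 + rng.uniform(-scale, scale)
    return [gain * x + rng.gauss(0.0, scale) for x in signal]


def two_views(signal, crop_len, rng):
    """Two independently augmented views of the same data point form a positive pair."""
    view_1 = jitter(random_crop(signal, crop_len, rng), rng)
    view_2 = jitter(random_crop(signal, crop_len, rng), rng)
    return view_1, view_2


rng = random.Random(0)                       # fixed seed for reproducibility
image = [float(i % 8) for i in range(32)]    # toy 1-D "image"
view_a, view_b = two_views(image, crop_len=24, rng=rng)
```

Both views come from the same source, so the model is asked to map them close together even though their raw values differ.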

# CL Examples

Contrastive learning has been around for some time. Over the last decade, it has come a long way, from early CNN-based image discrimination (which was content just to outperform SIFT) to CLIP. Some of the modern CL algorithms are:

  • Contrastive Predictive Coding, CPC
  • A simple framework for Contrastive Learning of (visual) Representations, SimCLR
  • Momentum Contrast, MoCo
  • Contrastive Language-Image Pretraining, CLIP

Let’s go through them briefly to get a better idea of how contrastive learning is being used practically.

# Contrastive Predictive Coding

CPC (Contrastive Predictive Coding) is inspired by predictive coding, a classical data-compression technique that has also been adapted in neuroscience. It tries to focus on the high-level information in the data while ignoring low-level information and noise.

CPC, in a nutshell, works as follows:

  • An encoder compresses high-dimensional data into a compact latent embedding space. This compression makes it easier to model the data and make predictions.
  • An autoregressive model summarizes the latents seen so far into a context vector. This context is used to predict future latents in the embedding space.
  • The model is trained with a contrastive loss based on Noise-Contrastive Estimation (NCE), which the paper calls InfoNCE.
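For reference, the CPC paper (Oord et al.) formalizes these steps roughly as follows, in its own notation:

  • An encoder $g_{\text{enc}}$ maps each input $x_t$ to a latent $z_t$.
  • An autoregressive model $g_{\text{ar}}$ summarizes past latents into a context $c_t$.
  • A log-bilinear score $f_k$ compares a candidate with the context $k$ steps ahead.
  • The loss picks out the true future sample among $N$ candidates in $X$. One candidate is the positive, and the other $N-1$ are negatives drawn from a proposal distribution.

$$
z_t = g_{\text{enc}}(x_t), \qquad c_t = g_{\text{ar}}(z_{\le t}), \qquad
f_k(x_{t+k}, c_t) = \exp\!\left(z_{t+k}^{\top} W_k\, c_t\right)
$$

$$
\mathcal{L}_N = -\,\mathbb{E}_{X}\!\left[\log \frac{f_k(x_{t+k}, c_t)}{\sum_{x_j \in X} f_k(x_j, c_t)}\right]
$$

The fraction is a softmax over the $N$ candidates. Minimizing $\mathcal{L}_N$ is therefore a cross-entropy for classifying the positive sample correctly.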

[Figure: CPC]

# SimCLR

SimCLR (A Simple framework for Contrastive Learning of visual Representations) is a contrastive learning technique for computer vision. It requires neither a specialized architecture nor a memory bank. It works as follows:

  • An image is chosen randomly and its views (two in the original implementation) are generated using different augmentation techniques like random cropping, random colour distortion or Gaussian blurring.
  • Image representation/embedding is computed using a ResNet-based CNN.
  • This representation is further transformed into a (non-linear) projection using an MLP.
  • Both the CNN and the MLP are trained to minimize a contrastive loss, which the paper calls NT-Xent (normalized temperature-scaled cross-entropy).
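A plain-Python sketch of the NT-Xent loss follows. It assumes `z1[i]` and `z2[i]` are the MLP projections of the two augmented views of image i. Each of the 2N projections must pick out its partner among the other 2N - 1, so the remaining 2N - 2 act as negatives. The helper names and the temperature value are illustrative assumptions.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def nt_xent(z1, z2, temperature=0.5):
    """NT-Xent over N positive pairs (z1[i], z2[i])."""
    z = z1 + z2  # 2N projections: first views, then second views
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        partner = i + n if i < n else i - n  # index of i's positive
        # temperature-scaled similarities to every other projection (k != i)
        logits = {k: cosine(z[i], z[k]) / temperature for k in range(2 * n) if k != i}
        top = max(logits.values())
        log_denominator = top + math.log(sum(math.exp(s - top) for s in logits.values()))
        total += log_denominator - logits[partner]
    return total / (2 * n)  # average over both orderings of every pair


views_a = [[1.0, 0.0], [0.0, 1.0]]
views_b = [[0.9, 0.1], [0.1, 0.9]]   # each close to its own partner
swapped = [[0.1, 0.9], [0.9, 0.1]]   # partners exchanged
print(nt_xent(views_a, views_b), nt_xent(views_a, swapped))
```

Unlike the simplest setting, where an item is compared with itself, the positive here is a different augmented view, and negatives come from both views of all other images in the batch.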

Throughout, we have emphasized learning without labels, but in practice some labelled data is usually available too. Fine-tuning the pre-trained CNN on a set of labelled images further improves its performance and generalization on various downstream tasks.

# Some Insights on How Contrastive Learning Works

SimCLR introduced a new model with very good performance (see its paper for a detailed analysis of the results). Its authors also offered some insights that apply to almost any contrastive learning method, and they are worth sharing here:

  • A combination of augmentation techniques is critical: random cropping and colour distortion do not give standout results when used individually, but together they give the best results.

  • A nonlinear projection is important: neural networks and the contrastive loss are complex, so it is hard to understand what is going on behind the scenes. Empirically, however, a nonlinear projection head (the MLP) clearly helps. It improves performance by more than 10% compared with using no projection. The MoCo v2 paper later confirmed this independently, as we will see shortly.

  • Scaling up improves performance: although these observations come from SimCLR, they largely hold for contrastive learning in general. Increasing the model's capacity (width or depth), the batch size, or the number of training epochs all lead to better performance.

# Momentum Contrast (MoCo)

Momentum Contrast (MoCo) takes an alternative view of contrastive learning as a dictionary lookup. This interesting point of view, having some similarities with the transformer models, works as follows:

  • Data augmentation is applied to produce two copies, $x^q$ and $x^k$.
  • The query encoder (the one on the left in the image below) takes $x^q$ and produces the query embedding $q$.
  • The momentum encoder takes the other augmented copy, $x^k$, and dynamically builds a dictionary of keys, $\{k_0, k_1, k_2, \dots\}$.

To keep the dictionary large but up to date, it is implemented as a queue. The keys of the most recent minibatch are enqueued, and the oldest minibatch is dequeued once the queue would exceed its size $K$. The key encoder itself is a momentum-based moving average of the query encoder, hence the name Momentum Contrast.

  • The encoded query $q$ is matched against the dictionary of keys, and the contrastive loss (InfoNCE) is computed.
  • Only the query encoder is updated by back-propagating this loss. The key encoder follows it through the momentum update.

[Figure: MoCo]

If you are familiar with transformers, you will notice that the InfoNCE loss closely resembles attention. A query is compared with a set of keys via dot products, and the scores are normalized with a softmax.
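Written out, the InfoNCE loss in the MoCo paper is defined for a query $q$, its positive key $k_{+}$, and the keys $k_0, \dots, k_K$. These keys comprise one positive and $K$ negatives. The loss is shown below next to the attention weights of scaled dot-product attention. Both are a softmax over query-key dot products, and the temperature $\tau$ plays a role similar to the $\sqrt{d_k}$ scaling. This is an analogy, not an equivalence: InfoNCE turns the softmax into a loss, while attention uses it to weight values.

$$
\mathcal{L}_q = -\log \frac{\exp(q \cdot k_{+} / \tau)}{\sum_{i=0}^{K} \exp(q \cdot k_i / \tau)}
\qquad\text{vs.}\qquad
\alpha_i = \frac{\exp\!\left(q \cdot k_i / \sqrt{d_k}\right)}{\sum_{j} \exp\!\left(q \cdot k_j / \sqrt{d_k}\right)}
$$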

# CLIP

Introduced in 2021, CLIP ups the ante by combining images with their captions. It does not use momentum-based averaging, but, like MoCo, it uses two encoders: one for text and one for images. Here is its workflow in brief:

  • The image is fed into the image encoder and its caption into the text encoder.
  • The image encoder (a ResNet or a ViT, Vision Transformer) produces the image embedding. The text encoder (a Transformer) tokenizes the caption and produces the text embedding. Both are projected into a shared embedding space, so each image-caption pair becomes a pair of points in that space.
  • The two encoders are trained jointly. Within a batch, they maximize the similarity of each correct (image, caption) pair and minimize the similarity of all the mismatched pairings.
  • At test time, we provide a set of candidate captions (not to be confused with the dynamic dictionary in MoCo) together with the image. The model returns the caption with the highest probability for that image.
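The plain-Python sketch below covers both halves of this workflow. It assumes the image and text embeddings have already been produced by the two encoders; the vectors are made-up placeholders. `clip_style_loss` applies cross-entropy in both directions: image-to-caption over rows and caption-to-image over columns, with matching pairs on the diagonal. `zero_shot` returns the caption most similar to an image. All names are hypothetical. The fixed temperature is a simplifying assumption, since CLIP learns its temperature during training.

```python
import math


def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def similarity_matrix(image_embs, text_embs):
    """Cosine similarity between every image (rows) and every caption (columns)."""
    imgs = [normalize(v) for v in image_embs]
    txts = [normalize(v) for v in text_embs]
    return [[sum(a * b for a, b in zip(i, t)) for t in txts] for i in imgs]


def cross_entropy_rows(logits):
    """Mean cross-entropy where row r's correct column is r (the matching pair)."""
    total = 0.0
    for r, row in enumerate(logits):
        top = max(row)
        total += top + math.log(sum(math.exp(x - top) for x in row)) - row[r]
    return total / len(logits)


def clip_style_loss(image_embs, text_embs, temperature=0.07):
    """Symmetric loss: image->text over rows plus text->image over columns."""
    sim = similarity_matrix(image_embs, text_embs)
    logits = [[s / temperature for s in row] for row in sim]
    transposed = [list(col) for col in zip(*logits)]
    return (cross_entropy_rows(logits) + cross_entropy_rows(transposed)) / 2


def zero_shot(image_emb, caption_embs, captions):
    """Return the caption whose embedding is most similar to the image embedding."""
    sims = similarity_matrix([image_emb], caption_embs)[0]
    best = max(range(len(sims)), key=lambda i: sims[i])
    return captions[best]


captions = ["a photo of a mango", "a photo of a dog"]
caption_embs = [[1.0, 0.1], [0.1, 1.0]]   # hypothetical text-encoder outputs
print(zero_shot([0.9, 0.2], caption_embs, captions))  # → a photo of a mango
```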

[Figure: CLIP]



# Code Example

Let's wrap things up with a code example to complete the picture. Here, I will use the official implementation of MoCo by Meta Research, which centres around the MoCo class.

# Constructor

The MoCo class's constructor initializes attributes such as K, m and T. Its defaults are a feature dimension (dim) of 128 and a queue size (K) of 65,536 ($2^{16}$) keys. The momentum coefficient (m) defaults to 0.999, which makes the key encoder a very slowly moving average. The softmax temperature (T, $\tau$ in the paper) is 0.07, as specified in the paper.

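The constructor also implements the MLP projection head that we first saw in SimCLR and that MoCo v2 later adopted. With mlp=True, the final fully connected layer of each encoder is replaced by a two-layer MLP. With mlp=False, you get the original MoCo setup, so this implementation covers both versions.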

```python
import torch
import torch.nn as nn


class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
            )
            self.encoder_k.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
            )

        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
```

# Momentum Encoder

The core part is simple yet so effective. In the momentum encoder, we are simply implementing this equation for momentum-based average:

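$$\theta_k \leftarrow m\,\theta_k + (1 - m)\,\theta_q$$

Here $\theta_k$ and $\theta_q$ are the parameters of the key and query encoders, and $m$ (self.m) is the momentum coefficient.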

```python
    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)
```
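Here is a toy simulation in plain Python that shows how slowly $m = 0.999$ moves. It uses floats instead of tensors. `momentum_update` is a hypothetical helper applying the same rule as `_momentum_update_key_encoder`: param_k = m * param_k + (1 - m) * param_q. The query encoder's single "parameter" jumps to 1.0 and stays there. After 1,000 steps, the key encoder has covered only about 63% of the gap, since $1 - 0.999^{1000} \approx 0.632$.

```python
def momentum_update(theta_k, theta_q, m=0.999):
    """One exponential-moving-average step, applied elementwise."""
    return [m * k + (1.0 - m) * q for k, q in zip(theta_k, theta_q)]


theta_q = [1.0]   # query encoder parameter after a sudden change
theta_k = [0.0]   # key encoder starts at the old value
for _ in range(1000):
    theta_k = momentum_update(theta_k, theta_q)
print(round(theta_k[0], 3))  # prints 0.632
```

This slow drift keeps the keys already stored in the queue consistent with newly encoded keys, which is the reason MoCo uses a large m.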

# Queue Management

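Finally, the queue is updated so that the newest keys replace the oldest ones. It works as follows:

  • The keys are first gathered from all GPUs. concat_all_gather is a helper in the official code for distributed training.
  • The new keys are written into the queue starting at the position stored in the queue pointer. They are transposed because the queue is stored with shape (dim, K), one key per column. This layout lets the forward pass compute all the negative logits with a single product between the queries and the queue.
  • Finally, the pointer moves past the batch that was just written (the modulus makes the queue circular), and the new value is stored in self.queue_ptr.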
```python
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        keys = concat_all_gather(keys)
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0

        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr
```
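The circular-pointer logic can be sketched in plain Python. `KeyQueue` is a hypothetical stand-in for the tensor queue. `columns[j]` plays the role of `self.queue[:, j]`, and there is no cross-GPU gathering.

```python
class KeyQueue:
    """Fixed-size FIFO of keys, one key per column as in the (dim, K) tensor queue."""

    def __init__(self, dim, K):
        self.K = K
        self.columns = [[0.0] * dim for _ in range(K)]  # columns[j] is the j-th stored key
        self.ptr = 0

    def dequeue_and_enqueue(self, keys):
        batch_size = len(keys)
        assert self.K % batch_size == 0  # batches tile the queue exactly, so no split writes
        for offset, key in enumerate(keys):
            self.columns[self.ptr + offset] = list(key)  # overwrite the oldest keys
        self.ptr = (self.ptr + batch_size) % self.K      # wrap around: circular buffer


q = KeyQueue(dim=2, K=4)
q.dequeue_and_enqueue([[1, 1], [2, 2]])  # ptr 0 -> 2
q.dequeue_and_enqueue([[3, 3], [4, 4]])  # ptr 2 -> 0 (wraps)
q.dequeue_and_enqueue([[5, 5], [6, 6]])  # overwrites the two oldest keys
print(q.columns, q.ptr)  # [[5, 5], [6, 6], [3, 3], [4, 4]] 2
```

The assertion mirrors the one in the official method: because K is a multiple of the batch size, a batch never straddles the end of the queue.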

# Training

The forward pass puts all of this together:

  • Computing query features
  • Computing key features (using the momentum update)
  • Computing the logits (both positive and negative)
  • Dequeue and enqueue using the function we just discussed above
```python
    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
        """

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels
```
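Why are the labels all zeros? `torch.cat([l_pos, l_neg], dim=1)` puts the positive logit in column 0 of every row. A standard cross-entropy loss with target 0 therefore computes InfoNCE over one positive and K negatives. The plain-Python sketch below shows this for a single query. The helper names and toy vectors are illustrative assumptions, and the vectors are chosen to be already L2-normalized.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def moco_logits(q, k, queue_keys, T=0.07):
    """Logits for one query: [positive, negative_1, ..., negative_K], divided by T."""
    l_pos = dot(q, k)                            # same image, different augmentation
    l_neg = [dot(q, key) for key in queue_keys]  # keys from earlier minibatches
    return [s / T for s in [l_pos] + l_neg]


def cross_entropy(logits, label):
    top = max(logits)
    return top + math.log(sum(math.exp(s - top) for s in logits)) - logits[label]


q = [1.0, 0.0]
k = [0.8, 0.6]
queue_keys = [[0.0, 1.0], [-1.0, 0.0]]
logits = moco_logits(q, k, queue_keys)
loss = cross_entropy(logits, label=0)  # label 0: the positive always sits in column 0
print(len(logits), loss)
```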

By the way, all the comments in this code were written by the authors themselves. Not only did they write a clear paper, they also wrote good, easy-to-understand code.


# Contrastive Learning and Vector Databases

Since contrastive learning centres on input embeddings and the distances between them, it is highly relevant to vector databases. With a trained CL model, we can embed new samples and compare them with stored embeddings to make predictions. The comparison uses a metric such as Euclidean distance or cosine similarity (or cosine distance).
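A small sketch shows why the choice between these metrics often matters less than it seems. Contrastive encoders such as MoCo L2-normalize their embeddings, as the normalize calls in the MoCo code above do. For unit vectors, the squared Euclidean distance equals $2 - 2\cos(a, b)$, so both metrics rank neighbours identically. The code is plain Python with made-up vectors.

```python
import math


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def unit(v):
    """Scale a vector to length 1, as contrastive encoders typically do."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


a, b = unit([3.0, 4.0]), unit([4.0, 3.0])
# For unit vectors: ||a - b||^2 == 2 - 2 * cos(a, b)
print(euclidean(a, b) ** 2, 2 - 2 * cosine_similarity(a, b))
```

For unnormalized embeddings, the two metrics can disagree, so it is worth matching the database metric to how the model was trained.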

As a result, contrastive learning combined with vector databases has many applications, for example:

  • Zero-shot recognition using CLIP (which itself has many applications)
  • Recommendation systems for online users: based on a user's history, suggest the products whose embeddings are closest to the embeddings of items the user has interacted with.
  • Document retrieval: retrieve the documents in the database that are closest to the embedding of the user's query.
  • Anomaly detection: to check whether a new credit-card transaction is normal or fraudulent, compare its embedding with the embeddings of the card's previous transactions. A transaction far from the "region" of its history in the embedding space is suspicious.
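Two of these uses, document retrieval and anomaly detection, are sketched below in brute-force plain Python. A vector database would replace the linear scan with an index for approximate nearest-neighbour search. The embeddings, ids, and the nearest-neighbour scoring rule are illustrative assumptions, not a prescribed method.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def top_k(query, items, k=2):
    """Ids of the k stored embeddings most similar to the query (linear scan)."""
    ranked = sorted(items.items(), key=lambda kv: cosine_similarity(query, kv[1]), reverse=True)
    return [item_id for item_id, _ in ranked[:k]]


def anomaly_score(new_emb, history):
    """1 - highest similarity to any past embedding: high means unlike anything seen before."""
    return 1.0 - max(cosine_similarity(new_emb, h) for h in history)


docs = {"doc_a": [0.9, 0.1, 0.0], "doc_b": [0.1, 0.9, 0.1], "doc_c": [0.8, 0.2, 0.1]}
print(top_k([1.0, 0.0, 0.0], docs))  # → ['doc_a', 'doc_c']

history = [[1.0, 0.1, 0.0], [0.9, 0.2, 0.0]]   # embeddings of past transactions
normal, unusual = [0.95, 0.15, 0.0], [0.0, 0.1, 1.0]
print(anomaly_score(normal, history) < anomaly_score(unusual, history))  # True
```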

# References

  1. Dosovitskiy, et al., Discriminative Unsupervised Feature Learning with Exemplar Convolutional Neural Networks, IEEE PAMI, 2016.
  2. Hjelm, et al., Learning Deep Representations by Mutual Information Estimation and Maximization, ICLR, 2019.
  3. Oord, et al., Representation Learning with Contrastive Predictive Coding, arXiv, 2018.
  4. Chen, et al., A Simple Framework for Contrastive Learning of Visual Representations, ICML, 2020.
  5. Radford, et al., Learning Transferable Visual Models From Natural Language Supervision, arXiv, 2021.
  6. He, et al., Momentum Contrast for Unsupervised Visual Representation Learning, CVPR, 2020.
  7. Chen, et al., Improved Baselines with Momentum Contrastive Learning, arXiv, 2020.