PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

Overview

Blog post with full documentation: Exploring SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

[Figure: SimCLR architecture]

See also the PyTorch implementation of BYOL (Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning).

Installation

$ conda env create --name simclr --file env.yml
$ conda activate simclr
$ python run.py

Config file

Before running SimCLR, make sure you choose the correct run configuration. You can change the configuration by passing command-line arguments to run.py:

$ python run.py -data ./datasets --dataset-name stl10 --log-every-n-steps 100 --epochs 100 

If you want to run it on CPU (for debugging purposes), use the --disable-cuda option.

For 16-bit precision GPU training, there is NO need to install NVIDIA apex. Just use the --fp16_precision flag and this implementation will use PyTorch's built-in AMP training.
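
What PyTorch's built-in AMP does inside a training step can be sketched as follows. This is a minimal sketch, not the repository's training loop: the toy linear model, the random data, and the `use_amp` switch (standing in for the --fp16_precision flag) are assumptions. It uses `torch.autocast` for the forward pass and a `GradScaler` for the backward pass, and falls back to plain float32 on CPU.

```python
import torch
import torch.nn as nn

# Placeholder setup: a toy model and random data, not the repository's SimCLR loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # stands in for the --fp16_precision flag; off on CPU

model = nn.Linear(32, 8).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
# With enabled=False the scaler passes values through unchanged, so this also runs on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(x, y):
    optimizer.zero_grad()
    # Inside autocast, eligible CUDA ops run in float16; the rest stays float32.
    with torch.autocast(device_type=device.type, enabled=use_amp):
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()  # scale the loss so small fp16 gradients do not underflow
    scaler.step(optimizer)         # unscale gradients; skip the update if they contain inf/NaN
    scaler.update()                # adapt the scale factor for the next iteration
    return loss.item()


x = torch.randn(16, 32, device=device)
y = torch.randint(0, 8, (16,), device=device)
print(train_step(x, y))
```

Because the scaler and autocast are both gated on the same flag, one code path serves full-precision and mixed-precision runs.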

Feature Evaluation

Feature evaluation is done using the linear evaluation protocol.

First, we learn features using SimCLR on the STL10 unlabeled (unsupervised) split. Then, we train a linear classifier on top of the frozen SimCLR features. The linear model is trained on features extracted from the STL10 train set and evaluated on the STL10 test set.
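
The protocol above can be sketched in a few lines: extract features once with a frozen encoder, fit a logistic-regression head (a single linear layer trained with cross-entropy and Adam) on the train features, and report top-1 accuracy on the test features. This is a minimal sketch under stated assumptions: the tiny `encoder` is a hypothetical stand-in for the pretrained ResNet, and the random tensors (with random labels, so accuracy stays near chance) stand in for the STL10 splits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained SimCLR backbone (the README uses ResNet-18/50).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 64), nn.ReLU())
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False  # freeze the backbone: only the linear head is trained


@torch.no_grad()
def extract_features(images):
    return encoder(images)


# Toy train/test splits; in the README these are the STL10 train and test sets.
x_train, y_train = torch.randn(256, 3, 8, 8), torch.randint(0, 10, (256,))
x_test, y_test = torch.randn(128, 3, 8, 8), torch.randint(0, 10, (128,))
f_train, f_test = extract_features(x_train), extract_features(x_test)

# Logistic regression = one linear layer trained with cross-entropy (Adam, as in the table).
classifier = nn.Linear(f_train.shape[1], 10)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

for epoch in range(100):
    optimizer.zero_grad()
    loss = criterion(classifier(f_train), y_train)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    top1 = (classifier(f_test).argmax(dim=1) == y_test).float().mean().item() * 100
print(f"top-1 accuracy: {top1:.2f}%")
```

Extracting features once up front is a design choice: since the encoder is frozen, recomputing features every epoch would give the same values at a much higher cost.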

Check the Colab notebook for reproducibility.

Note that SimCLR benefits from longer training.

| Linear Classification | Dataset | Feature Extractor | Architecture | Feature dimensionality | Projection Head dimensionality | Epochs | Top-1 (%) |
|---|---|---|---|---|---|---|---|
| Logistic Regression (Adam) | STL10 | SimCLR | ResNet-18 | 512 | 128 | 100 | 74.45 |
| Logistic Regression (Adam) | CIFAR10 | SimCLR | ResNet-18 | 512 | 128 | 100 | 69.82 |
| Logistic Regression (Adam) | STL10 | SimCLR | ResNet-50 | 2048 | 128 | 50 | 70.075 |

Comments
  • A question about the "labels"

    Hi! I have a question about the definition of "labels" in the script "simclr.py".

    On line 54 of "simclr.py", the authors defined:

    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

    So all entries of "labels" are zeros. But according to the paper, shouldn't there be an entry of 1 for the positive pair?

    Thanks in advance for your reply!

    opened by kekehia123 6
  • size of tensors in cosine_similarity function

    Hi, I'm trying to understand the code in loss/nt_xent.py.

    We pass "representations" as both arguments:

        def forward(self, zis, zjs):
            representations = torch.cat([zjs, zis], dim=0)
            similarity_matrix = self.similarity_function(representations, representations)
    

    But inside the cosine similarity function, the shapes are documented as x: (N, 1, C) and y: (1, 2N, C). How can one be double the other if the same argument is passed to both?

        def _cosine_simililarity(self, x, y):
            # x shape: (N, 1, C)
            # y shape: (1, 2N, C)
            # v shape: (N, 2N)
            v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
            return v
    

    Thanks for your help.

    opened by BattashB 5
  • How do i train the SimCLR model with my local dataset?

    Dear researcher, thank you for the open-source code; it has helped me a lot in understanding contrastive learning. However, I am still confused about how to train the SimCLR model on my local dataset. Could you give me some guidance or tips? I would appreciate a reply to this issue.

    opened by bestalllen 4
  • Question about CE Loss

    Hello,

    Thanks for sharing the code, nice implementation.

    The way you calculate the loss by using a mask is quite brilliant. But I have a question.

    logits = torch.cat((positives, negatives), dim=1)

    So if I'm not wrong, the first column of logits is the positive and the rest are negatives.

    labels = torch.zeros(2 * self.batch_size).to(self.device).long()

    But your labels are all zeros, which means that, whether positive or negative, the similarity should be low.

    So I wonder whether the first column of labels is supposed to be 1 instead of 0.

    Thanks for your help.

    opened by WShijun1991 4
  • Issue with batch-size

    In info_nce_loss, line 28 creates the labels based on batch_size. On the other hand, the STL10 dataset has 100,000 images, which is divisible by a batch_size of 32, whereas a batch_size of 128 or 64 leaves a remainder of 32.

    With batch_size != 32, line 42 raises an error, because the similarity matrix is built from the features while the labels are built from the batch size.

    For instance, with batch_size = 128, the last iteration of data_loader has only 32 remaining images. Since we create two views of each image, we get 64 images. Line 28 still produces 128 x 2 = 256 labels, while the similarity matrix is (64 x 128) x (128 x 64) => (64 x 64), so the (256 x 256) mask causes a "dimension mismatch".

    Solution: change line 28 as follows:

        labels = torch.cat([torch.arange(features.shape[0]//2) for i in range(self.args.n_views)], dim=0)

    opened by Mayurji 3
  • 'CosineAnnealingLR' never works with the wrong position of 'scheduler.step()'

    Given the setting 'scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)', I think 'scheduler.step()' should be called at every step of 'for (xis,xjs),_ in train_loader'. Otherwise, the learning rate is annealed over len(train_loader) epochs rather than len(train_loader) steps.

    opened by GuohongLi 3
  • Is it something wrong with the training model for CIFAR-10 experiments?

    Hi,

    I find that the ResNet20 model for the CIFAR-10 experiments is not fully correct. The stem convolution should be modified (stride=1 and no pooling) because CIFAR-10 images are very small.

    opened by timqqt 2
  • GPU utilization rate is low

    Hi, thanks for the code!

    When I run it on a single GPU (V100), the utilization rate is very low (~0-10%), even if I increase num_workers. Do you know why this happens and how to solve it? Thanks!

    opened by LiJunnan1992 2
  • Why cos_sim after L2 norm?

    Hi, this code is really useful for me. Thanks! But I have a question about the NT-Xent loss. I noticed that you apply L2 normalization to z and then use cosine similarity. But cosine similarity already includes L2 normalization. Why apply the L2 norm first?

    opened by BoPang1996 2
  • NT_Xent Loss function: all negatives are not being used?

    Hi @sthalles, thank you for sharing your code!

    Please correct me if I am wrong: in loss/nt_xent.py, line 57 (below), it seems you are not computing the contrastive loss for all negative pairs, because the negatives are reshaped into a 2D array, i.e. only some of the negative pairs are used for each positive pair. Is that right?

        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
        logits = torch.cat((positives, negatives), dim=1)
    

    Hope to hear from you soon.

    -Ishan

    opened by DAVEISHAN 2
  • Validation Loss calculation

    First of all, thank you for your great work!

    The _validate method in simclr.py raises ZeroDivisionError at line 148 if the validation data loader performs only one iteration (since the counter starts from 0).

    opened by alessiamarcolini 2
  • evaluation code batch_size & validation process

    I really appreciate your work :) I have a question because I got confused while studying your code.

    First, I wonder why you used "batch_size=batch_size*2" for the test_loader, unlike the train_loader, in the file "mini_batch_logistic_regression_valuator.ipynb". Is it related to creating 2 views during data augmentation?

    Also, in the last cell of this file, I'm confused whether the second of the two inner "for" loops inside the epoch loop corresponds to testing or validation. I thought it was testing, because the loss computation, backpropagation, and optimization happen only in the first loop, while the second loop only computes accuracy. Is that right? Or is the second loop a validation step, since both loops run within every epoch?

    opened by YejinS 0
  • Review Training | Fine-Tune | Test details

    Hi, I just want to check all the experiment details and make sure I didn't miss anything.

    1. Training phase: use SimCLR (two encoder branches) to train on ImageNet for 1000 epochs to get initial pretrained weights.
    2. Fine-tuning: load the pretrained weights into ResNet-18 (50/101/...) with frozen parameters, attach a linear classifier, and train the classifier on the CIFAR10/STL10 training set for 100 epochs.
    3. Test phase: freeze all encoder and classifier parameters, and test on the CIFAR10/STL10 test set.

    Is this how you got the top-1 accuracy in the README?

    opened by Howeng98 0
  • Confusion matrix

    Does anyone know how to add a confusion matrix to this code? I added one based on an example I found online, but something went wrong and I can't figure out what. Please help me! Thanks.

        import torch

        def confusion_matrix(output, labels, conf_matrix):
            preds = torch.argmax(output, dim=-1)
            for p, t in zip(preds, labels):
                conf_matrix[p, t] += 1
            return conf_matrix

    opened by here101 0
  • batch size affect

    Hi, I'm experimenting with CIFAR-10 using the default hyper-parameters, and smaller batch sizes seem to yield better scores (e.g. 72% with batch size 256 but 78% with batch size 128). Is anyone else seeing this?

    opened by VietHoang1512 1
  • ModuleNotFoundError: No module named 'torch.cuda'

    I am using Python 3.7 on Win10 with Anaconda and Jupyter. I have successfully installed torch-1.10.0+cu113, torchaudio-0.10.0+cu113 and torchvision-0.11.1+cu113. When trying to import torch, I get ModuleNotFoundError: No module named 'torch.cuda'. Detailed error:

    ModuleNotFoundError                       Traceback (most recent call last)
    <ipython-input-1-bfd2c657fa76> in <module>
          1 import numpy as np
          2 import pandas as pd
    ----> 3 import torch
          4 import torch.nn as nn
          5 from sklearn.model_selection import train_test_split
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\__init__.py in <module>
        603 
        604 # Shared memory manager needs to know the exact location of manager executable
    --> 605 _C._initExtension(manager_path())
        606 del manager_path
        607 
    
    ModuleNotFoundError: No module named 'torch.cuda'
    

    I found posts about a similar error, No module named 'torch.cuda.amp'. However, none of the suggested solutions worked. Please advise.

    opened by m-bor 0
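
Regarding the two questions above about "labels" being all zeros: `nn.CrossEntropyLoss` (and `F.cross_entropy`) take class indices as targets, not one-hot vectors. Since the logits are built as `torch.cat([positives, negatives], dim=1)`, index 0 points at the positive column, so an all-zeros target means "the positive is the correct class" for every row. The toy logits below are made up purely to illustrate this.

```python
import torch
import torch.nn.functional as F

# One row per anchor: column 0 holds the positive similarity, the rest are negatives.
logits = torch.tensor([[5.0, 0.1, -0.3, 0.2],
                       [4.0, 0.5, 0.0, -1.0]])
# Class *indices*: 0 means "the correct class is column 0", i.e. the positive pair.
# It does not mean "all similarities should be 0".
labels = torch.zeros(logits.shape[0], dtype=torch.long)

loss = F.cross_entropy(logits, labels)
# Same value by hand: -log softmax at the positive column, averaged over rows.
manual = -(logits.log_softmax(dim=1)[:, 0]).mean()
print(loss.item(), manual.item())
```

Minimizing this loss pushes the positive logit up relative to the negatives, which is exactly the contrastive objective.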
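
On the tensor-shape question in loss/nt_xent.py: both arguments are the same `(2N, C)` tensor, and the different shapes come only from `unsqueeze`. `x.unsqueeze(1)` is `(2N, 1, C)` and `y.unsqueeze(0)` is `(1, 2N, C)`, so the leading dimension of `x` is also 2N (not N as the quoted comment suggests), and broadcasting yields a `(2N, 2N)` matrix of pairwise similarities. A small sketch with made-up sizes:

```python
import torch
import torch.nn as nn

N, C = 4, 16                                   # N images per view, C-dim projections
zis, zjs = torch.randn(N, C), torch.randn(N, C)
representations = torch.cat([zjs, zis], dim=0)  # shape (2N, C)

cos = nn.CosineSimilarity(dim=-1)
x = representations.unsqueeze(1)  # (2N, 1, C)
y = representations.unsqueeze(0)  # (1, 2N, C)
# Broadcasting expands both to (2N, 2N, C); the similarity reduces over C -> (2N, 2N).
sim = cos(x, y)
print(sim.shape)  # torch.Size([8, 8])
```

Entry `sim[i, j]` compares row i with row j of the same tensor, so the diagonal holds self-similarities of 1.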
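
On training with a local dataset: SimCLR needs two independently augmented views of each image, and the labels are not used by the contrastive loss. A minimal sketch under these assumptions: any dataset that returns `(image, label)` pairs (for a folder of images, `torchvision.datasets.ImageFolder` is a common choice) can be wrapped so that each item yields a list of views. `TwoViewDataset` and the noise-based `augment` function are hypothetical placeholders; in practice `augment` would be the SimCLR augmentation pipeline.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class TwoViewDataset(Dataset):
    """Hypothetical wrapper: returns n_views independently augmented copies of each sample."""

    def __init__(self, base, transform, n_views=2):
        self.base, self.transform, self.n_views = base, transform, n_views

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, label = self.base[idx]  # label is ignored by the contrastive loss
        return [self.transform(x) for _ in range(self.n_views)], label


def augment(img):
    # Placeholder augmentation (additive noise); replace with the SimCLR transforms.
    return img + 0.1 * torch.randn_like(img)


# Stand-in for a local image dataset of (image, label) pairs.
base = [(torch.rand(3, 32, 32), 0) for _ in range(10)]
loader = DataLoader(TwoViewDataset(base, augment), batch_size=4, shuffle=True, drop_last=True)
views, _ = next(iter(loader))
print(len(views), views[0].shape)  # 2 torch.Size([4, 3, 32, 32])
```

The default collate function turns the per-sample list of views into a list of two batched tensors, which can then be concatenated along dimension 0 before the encoder.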
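
On the batch-size and "all negatives" questions: if the labels are derived from `features.shape[0]` rather than from the configured batch size, a smaller final batch no longer causes a shape mismatch (passing `drop_last=True` to the DataLoader is another option). The same sketch also shows that reshaping the negatives into a 2D array keeps every one of the 2B - 2 negatives for each anchor. `info_nce_logits` is a hypothetical standalone function written for two views, modeled on the masking approach the issues describe; it is not the repository's exact code, and the temperature is a placeholder.

```python
import torch
import torch.nn.functional as F


def info_nce_logits(features, temperature=0.07):
    """Sketch of InfoNCE/NT-Xent logits for two views.

    `features` holds view 1 of every image followed by view 2, shape (2B, D).
    B is read from the tensor, so a smaller final batch is handled.
    """
    batch = features.shape[0] // 2
    device = features.device
    ids = torch.cat([torch.arange(batch), torch.arange(batch)], dim=0).to(device)
    pos_mask = ids.unsqueeze(0) == ids.unsqueeze(1)  # (2B, 2B), True for the same image

    features = F.normalize(features, dim=1)
    similarity = features @ features.T  # cosine similarities, (2B, 2B)

    # Remove self-comparisons on the diagonal from both matrices.
    not_self = ~torch.eye(2 * batch, dtype=torch.bool, device=device)
    pos_mask = pos_mask[not_self].view(2 * batch, -1)      # (2B, 2B - 1)
    similarity = similarity[not_self].view(2 * batch, -1)  # (2B, 2B - 1)

    positives = similarity[pos_mask].view(2 * batch, 1)    # one positive per row
    negatives = similarity[~pos_mask].view(2 * batch, -1)  # all 2B - 2 negatives per row

    logits = torch.cat([positives, negatives], dim=1) / temperature
    targets = torch.zeros(2 * batch, dtype=torch.long, device=device)  # positive is column 0
    return logits, targets


# Final STL10 batch with batch_size=128: only 32 images remain, i.e. 64 feature rows.
features = torch.randn(64, 128)
logits, targets = info_nce_logits(features)
loss = F.cross_entropy(logits, targets)
print(logits.shape)  # torch.Size([64, 63])
```

Each row has 63 columns: one positive plus 62 negatives, which is every other sample in the batch except the anchor and its own positive.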
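
On the CosineAnnealingLR question: `T_max` is counted in calls to `scheduler.step()`, so it has to match how often `step()` is called. With `T_max=len(train_loader)`, the scheduler should step once per batch; alternatively, step once per epoch with `T_max` set to the number of epochs. The sketch below uses the per-iteration variant over the whole run; the loop sizes, the single dummy parameter, and the learning rate are placeholders.

```python
import torch

steps_per_epoch, epochs = 10, 3
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=3e-4)
# T_max covers every scheduler.step() call of the run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=steps_per_epoch * epochs, eta_min=0)

lrs = []
for epoch in range(epochs):
    for step in range(steps_per_epoch):  # stands in for `for batch in train_loader`
        optimizer.step()                 # forward/backward omitted in this sketch
        scheduler.step()                 # per-iteration annealing
        lrs.append(optimizer.param_groups[0]["lr"])
print(lrs[0], lrs[-1])
```

The learning rate then decreases at every iteration and reaches `eta_min` at the end of training, instead of barely moving within each epoch.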
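
On the CIFAR-10 stem question: the standard ResNet stem (7x7 stride-2 convolution plus stride-2 max-pooling) shrinks a 32x32 input to 8x8 before the first residual stage. A common CIFAR adaptation, sketched below as an assumption rather than the repository's code, uses a 3x3 stride-1 convolution and no max-pooling. For a torchvision ResNet this roughly corresponds to replacing `model.conv1` and setting `model.maxpool = nn.Identity()`.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 32, 32)  # a CIFAR-10-sized batch

# ImageNet-style stem: 7x7 stride-2 conv followed by 3x3 stride-2 max-pool.
imagenet_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.BatchNorm2d(64), nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
# CIFAR-style stem: 3x3 stride-1 conv and no max-pool, so resolution is preserved.
cifar_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(64), nn.ReLU(inplace=True),
)
print(imagenet_stem(x).shape)  # torch.Size([2, 64, 8, 8])
print(cifar_stem(x).shape)     # torch.Size([2, 64, 32, 32])
```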
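
On the L2-normalization question: cosine similarity does not depend on the scale of its inputs, so normalizing first leaves the result unchanged. The extra step is redundant but harmless. The practical benefit of normalizing explicitly is that, for unit vectors, a plain matrix product already equals the cosine similarity. A quick check on random vectors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(6, 128) * 5  # un-normalized projections
z_n = F.normalize(z, dim=1)  # explicit L2 normalization

cos_raw = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=-1)
cos_norm = F.cosine_similarity(z_n.unsqueeze(1), z_n.unsqueeze(0), dim=-1)
dot_norm = z_n @ z_n.T  # for unit vectors, dot product == cosine similarity

print(torch.allclose(cos_raw, cos_norm, atol=1e-5), torch.allclose(cos_raw, dot_norm, atol=1e-5))
```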
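
On the validation-loss report: without the exact `_validate` code at hand, the sketch below only illustrates one division-safe way to average per-batch losses, counting each batch as it is processed. `mean_validation_loss` is a hypothetical helper, and the plain list of floats stands in for losses computed over the validation loader.

```python
def mean_validation_loss(batch_losses):
    """Average per-batch losses; the count equals the number of batches seen."""
    total, count = 0.0, 0
    for loss in batch_losses:  # stands in for iterating over the validation loader
        total += loss
        count += 1             # incremented once per batch, so one batch gives count == 1
    if count == 0:
        raise ValueError("validation loader produced no batches")
    return total / count


print(mean_validation_loss([2.0]))  # 2.0
```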
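
On the confusion-matrix question: a vectorized alternative to the Python loop is `torch.bincount` over `true * num_classes + pred`. The helper below (`batch_confusion_matrix`, a hypothetical name) puts true labels on rows and predictions on columns. The loop in the question indexes `conf_matrix[p, t]`, i.e. predictions on rows, so the two results are transposes of each other. Accumulate over an epoch by summing the per-batch matrices.

```python
import torch


def batch_confusion_matrix(output, labels, num_classes):
    """Rows = true class, columns = predicted class (a convention choice)."""
    preds = output.argmax(dim=-1)
    idx = labels * num_classes + preds  # flatten (true, pred) pairs into one index
    counts = torch.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


out = torch.tensor([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]])
lab = torch.tensor([0, 1, 2, 2])
print(batch_confusion_matrix(out, lab, num_classes=3))
```

Here the third sample (true class 2, predicted class 1) lands in row 2, column 1.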
Releases: v1.0.1