Official code for the paper "Why Do Self-Supervised Models Transfer? Investigating the Impact of Invariance on Downstream Tasks".

Overview

Why Do Self-Supervised Models Transfer? Investigating the Impact of Invariance on Downstream Tasks

This repository contains the official code for the paper Why Do Self-Supervised Models Transfer? Investigating the Impact of Invariance on Downstream Tasks.

Requirements

This codebase has been tested with the following package versions:

python=3.8.8
torch=1.9.0+cu102
torchvision=0.10.0+cu102
PIL=8.1.0
numpy=1.19.2
scipy=1.6.1
tqdm=4.57.0
sklearn=0.24.1
albumentations=1.0.3

Prepare data

There are several classes defined in the datasets directory. The data is expected in a directory name data, located on the same level as this repository. Below is an outline of the expected file structure:

data/
    imagenet/
    CIFAR10/
    300W/
    ...
ssl-invariances/
    datasets/
    models/
    readme.md
    ...

For synthetic invariance evaluation, get the ILSVRC2012 validation data from https://image-net.org/ and store in ../data/imagenet/val/.

For real-world invariances, download the following datasets: Flickr1024, COIL-100, ALOI, ALOT, DaLI, ExposureErrors, RealBlur.

For extrinsic invariances, get Causal3DIdent.

Finally, our downstream datasets are CIFAR10, Caltech101, Flowers, 300W, CelebA, LSPose.

Pre-training models

We pre-train several models based on the MoCo codebase.

To set up a version of the codebase that can pre-train our models, first clone the MoCo repo onto the same level as this repo:

git clone https://github.com/facebookresearch/moco

This should be the resulting file structure:

data/
ssl-invariances/
moco/

Then copy the files from ssl-invariances/pretraining/ into the cloned repo:

cp ssl-invariances/pretraining/* moco/

Finally, to run our models, enter the cloned repo by cd moco and run one of the following:

# train the Default model
python main_moco.py -a resnet50 --model default --lr 0.03 --batch-size 256 --mlp --moco-t 0.2 --cos --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 ../data/imagenet

# train the Ventral model
python main_moco.py -a resnet50 --model ventral --lr 0.03 --batch-size 256 --mlp --moco-t 0.2 --cos --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 ../data/imagenet

# train the Dorsal model
python main_moco.py -a resnet50 --model dorsal --lr 0.03 --batch-size 256 --mlp --moco-t 0.2 --cos --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 ../data/imagenet

# train the Default(x3) model
python main_moco.py -a resnet50w3 --model default --moco-dim 384 --lr 0.03 --batch-size 256 --mlp --moco-t 0.2 --cos --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 ../data/imagenet

This will train the models for 200 epochs and save checkpoints. When training has completed, the final model checkpoint, e.g. default_00199.pth.tar, should be moved to ssl-invariances/models/default.pth.tarfor use in evaluation in the below code.

The rest of this codebase assumes these final model checkpoints are located in a directory called ssl-invariances/models/ as shown below.

ssl-invariances/
    models/
        default.pth.tar
        default_w3.pth.tar
        dorsal.pth.tar
        ventral.pth.tar

Synthetic invariance

To evaluate the Default model on grayscale invariance, run:

python eval_synthetic_invariance.py --model default --transform grayscale ../data/imagenet

This will compute the mean and covariance of the model's feature space and save these statistics in the results/ directory. These are then used to speed up future invariance computations for the same model.

Real-world invariance

To evaluate the Ventral model on COIL100 viewpoint invariance, run:

python eval_realworld_invariance.py --model ventral --dataset COIL100

Extrinsic invariance on Causal3DIdent

To evaluate the Dorsal model on Causal3DIdent object x position prediction, run:

python eval_causal3dident.py --model dorsal --target 0

Downstream performance

To evaluate the combined Def+Ven+Dor model on 300W facial landmark regression, run:

python eval_downstream.py --model default+ventral+dorsal --dataset 300w

Citation

If you find our work useful for your research, please consider citing our paper:

@misc{ericsson2021selfsupervised,
      title={Why Do Self-Supervised Models Transfer? Investigating the Impact of Invariance on Downstream Tasks}, 
      author={Linus Ericsson and Henry Gouk and Timothy M. Hospedales},
      year={2021},
      eprint={2111.11398},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

If you have any questions, feel welcome to create an issue or contact Linus Ericsson ([email protected]).

Owner
Linus Ericsson
PhD student in the Data Science CDT at The University of Edinburgh
Linus Ericsson
Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting

Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting This is the origin Pytorch implementation of Informer in the followin

Haoyi 3.1k Dec 29, 2022
This repository contains the code needed to train Mega-NeRF models and generate the sparse voxel octrees

Mega-NeRF This repository contains the code needed to train Mega-NeRF models and generate the sparse voxel octrees used by the Mega-NeRF-Dynamic viewe

cmusatyalab 260 Dec 28, 2022
SBINN: Systems-biology informed neural network

SBINN: Systems-biology informed neural network The source code for the paper M. Daneker, Z. Zhang, G. E. Karniadakis, & L. Lu. Systems biology: Identi

Lu Group 15 Nov 19, 2022
DrQ-v2: Improved Data-Augmented Reinforcement Learning

DrQ-v2: Improved Data-Augmented RL Agent Method DrQ-v2 is a model-free off-policy algorithm for image-based continuous control. DrQ-v2 builds on DrQ,

Facebook Research 234 Jan 01, 2023
A dataset for online Arabic calligraphy

Calliar Calliar is a dataset for Arabic calligraphy. The dataset consists of 2500 json files that contain strokes manually annotated for Arabic callig

ARBML 114 Dec 28, 2022
Localization Distillation for Object Detection

Localization Distillation for Object Detection This repo is based on mmDetection. This is the code for our paper: Localization Distillation

274 Dec 26, 2022
π-GAN: Periodic Implicit Generative Adversarial Networks for 3D-Aware Image Synthesis

π-GAN: Periodic Implicit Generative Adversarial Networks for 3D-Aware Image Synthesis Project Page | Paper | Data Eric Ryan Chan*, Marco Monteiro*, Pe

375 Dec 31, 2022
Robot Hacking Manual (RHM). From robotics to cybersecurity. Papers, notes and writeups from a journey into robot cybersecurity.

RHM: Robot Hacking Manual Download in PDF RHM v0.4 ┃ Read online The Robot Hacking Manual (RHM) is an introductory series about cybersecurity for robo

Víctor Mayoral Vilches 233 Dec 30, 2022
Code needed to reproduce the examples found in "The Temporal Robustness of Stochastic Signals"

The Temporal Robustness of Stochastic Signals Code needed to reproduce the examples found in "The Temporal Robustness of Stochastic Signals" Case stud

0 Oct 28, 2021
Simple and ready-to-use tutorials for TensorFlow

TensorFlow World To support maintaining and upgrading this project, please kindly consider Sponsoring the project developer. Any level of support is a

Amirsina Torfi 4.5k Dec 23, 2022
Customised to detect objects automatically by a given model file(onnx)

LabelImg LabelImg is a graphical image annotation tool. It is written in Python and uses Qt for its graphical interface. Annotations are saved as XML

Heeone Lee 1 Jun 07, 2022
GitHub repository for "Improving Video Generation for Multi-functional Applications"

Improving Video Generation for Multi-functional Applications GitHub repository for "Improving Video Generation for Multi-functional Applications" Pape

Bernhard Kratzwald 328 Dec 07, 2022
Deep Learning Models for Causal Inference

Extensive tutorials for learning how to build deep learning models for causal inference using selection on observables in Tensorflow 2.

Bernard J Koch 151 Dec 31, 2022
Code for the ICML 2021 paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

ViLT Code for the paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision" Install pip install -r requirements.txt pip

Wonjae Kim 922 Jan 01, 2023
This is a collection of our NAS and Vision Transformer work.

AutoML - Neural Architecture Search This is a collection of our AutoML-NAS work iRPE (NEW): Rethinking and Improving Relative Position Encoding for Vi

Microsoft 828 Dec 28, 2022
U-Net for GBM

My Final Year Project(FYP) In National University of Singapore(NUS) You need Pytorch(stable 1.9.1) Both cuda version and cpu version are OK File Str

PinkR1ver 1 Oct 27, 2021
U-Time: A Fully Convolutional Network for Time Series Segmentation

U-Time & U-Sleep Official implementation of The U-Time [1] model for general-purpose time-series segmentation. The U-Sleep [2] model for resilient hig

Mathias Perslev 176 Dec 19, 2022
Agile SVG maker for python

Agile SVG Maker Need to draw hundreds of frames for a GIF? Need to change the style of all pictures in a PPT? Need to draw similar images with differe

SemiWaker 4 Sep 25, 2022
Omniscient Video Super-Resolution

Omniscient Video Super-Resolution This is the official code of OVSR (Omniscient Video Super-Resolution, ICCV 2021). This work is based on PFNL. Datase

36 Oct 27, 2022
Repository for the paper "PoseAug: A Differentiable Pose Augmentation Framework for 3D Human Pose Estimation", CVPR 2021.

PoseAug: A Differentiable Pose Augmentation Framework for 3D Human Pose Estimation Code repository for the paper: PoseAug: A Differentiable Pose Augme

Pyjcsx 328 Dec 17, 2022