PyTorch code for Vision Transformers training with the Self-Supervised learning method DINO

Related tags

Deep Learningdino
Overview

Self-Supervised Vision Transformers with DINO

PyTorch implementation and pretrained models for DINO. For details, see Emerging Properties in Self-Supervised Vision Transformers.
[blogpost] [arXiv]

DINO illustration

Pretrained models

You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide the training and evaluation logs.

arch params k-nn linear download
DeiT-S/16 21M 74.5% 77.0% backbone only full checkpoint args logs eval logs
DeiT-S/8 21M 78.3% 79.7% backbone only full checkpoint args logs eval logs
ViT-B/16 85M 76.1% 78.2% backbone only full checkpoint args logs eval logs
ViT-B/8 85M 77.4% 80.1% backbone only full checkpoint args logs eval logs
ResNet-50 23M 67.5% 75.3% backbone only full checkpoint args logs eval logs

The pretrained models are available on PyTorch Hub.

import torch
deits16 = torch.hub.load('facebookresearch/dino', 'dino_deits16')
deits8 = torch.hub.load('facebookresearch/dino', 'dino_deits8')
vitb16 = torch.hub.load('facebookresearch/dino', 'dino_vitb16')
vitb8 = torch.hub.load('facebookresearch/dino', 'dino_vitb8')
resnet50 = torch.hub.load('facebookresearch/dino', 'dino_resnet50')

Training

Documentation

Please install PyTorch and download the ImageNet dataset. This codebase has been developed with python version 3.6, PyTorch version 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the args column of the pretrained models section. For a glimpse at the full documentation of DINO training please run:

python main_dino.py --help

Vanilla DINO training 🦕

Run DINO with DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 day and the resulting checkpoint should reach ~69.3% on k-NN eval and ~73.8% on linear eval. We will shortly provide training and linear evaluation logs for this run to help reproducibility.

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Multi-node training

We use Slurm and submitit (pip install submitit). To train on 2 nodes with 8 GPUs each (total 16 GPUs):

python run_with_submitit.py --nodes 2 --ngpus 8 --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
DINO with ViT-base network.
python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base  --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Boosting DINO performance 🦖

You can improve the performance of the vanilla run by:

  • training for more epochs: --epochs 300,
  • increasing the teacher temperature: --teacher_temp 0.07 --warmup_teacher_temp_epochs 30.
  • removing last layer normalization (only safe with --arch deit_small): --norm_last_layer false,
Full command.
python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

The resulting pretrained model should reach ~73.4% on k-NN eval and ~76.1% on linear eval. Training time is 2.6 days with 16 GPUs. We will shortly provide training and linear evaluation logs for this run to help reproducibility.

ResNet-50 and other convnets trainings

This code also works for training DINO on convolutional networks, like ResNet-50 for example. We highly recommend to adapt some optimization arguments in this case. For example here is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs:

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Evaluation: k-NN classification on ImageNet

To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet

If you choose not to specify --pretrained_weights, then DINO reference weights are used by default. If you want instead to evaluate checkpoints from a run of your own, you can run for example:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet 

Evaluation: Linear classification on ImageNet

To train a supervised linear classifier on frozen weights on a single node with 8 gpus, run:

python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet

Self-attention visualization

You can look at the self-attention of the [CLS] token on the different heads of the last layer by running:

python visualize_attention.py
Self-attention from a Vision Transformer with 8x8 patches trained with DINO

License

See the LICENSE file for more details.

Citation

If you find this repository useful, please consider giving a star and citation 🦖 :

@article{caron2021emerging,
  title={Emerging Properties in Self-Supervised Vision Transformers},
  author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e  and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  journal={arXiv preprint arXiv:2104.14294},
  year={2021}
}
Owner
Facebook Research
Facebook Research
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

JAX: Autograd and XLA Quickstart | Transformations | Install guide | Neural net libraries | Change logs | Reference docs | Code search News: JAX tops

Google 21.3k Jan 01, 2023
A PaddlePaddle implementation of Time Interval Aware Self-Attentive Sequential Recommendation.

TiSASRec.paddle A PaddlePaddle implementation of Time Interval Aware Self-Attentive Sequential Recommendation. Introduction 论文:Time Interval Aware Sel

Paddorch 2 Nov 28, 2021
Speech Recognition using DeepSpeech2.

deepspeech.pytorch Implementation of DeepSpeech2 for PyTorch using PyTorch Lightning. The repo supports training/testing and inference using the DeepS

Sean Naren 2k Jan 04, 2023
Official implementation of paper "Query2Label: A Simple Transformer Way to Multi-Label Classification".

Introdunction This is the official implementation of the paper "Query2Label: A Simple Transformer Way to Multi-Label Classification". Abstract This pa

Shilong Liu 274 Dec 28, 2022
Wenet STT Python

Wenet STT Python Beta Software Simple Python library, distributed via binary wheels with few direct dependencies, for easily using WeNet models for sp

David Zurow 33 Feb 21, 2022
Code for "Training Neural Networks with Fixed Sparse Masks" (NeurIPS 2021).

Fisher Induced Sparse uncHanging (FISH) Mask This repo contains the code for Fisher Induced Sparse uncHanging (FISH) Mask training, from "Training Neu

Varun Nair 37 Dec 30, 2022
UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation

UnivNet UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation. Training python train.py --c

Rishikesh (ऋषिकेश) 55 Dec 26, 2022
code for TCL: Vision-Language Pre-Training with Triple Contrastive Learning, CVPR 2022

Vision-Language Pre-Training with Triple Contrastive Learning, CVPR 2022 News (03/16/2022) upload retrieval checkpoints finetuned on COCO and Flickr T

187 Jan 02, 2023
Pytorch Implementations of large number classical backbone CNNs, data enhancement, torch loss, attention, visualization and some common algorithms.

Torch-template-for-deep-learning Pytorch implementations of some **classical backbone CNNs, data enhancement, torch loss, attention, visualization and

Li Shengyan 270 Dec 31, 2022
Implementation of light baking system for ray tracing based on Activision's UberBake

Vulkan Light Bakary MSU Graphics Group Student's Diploma Project Treefonov Andrey [GitHub] [LinkedIn] Project Goal The goal of the project is to imple

Andrey Treefonov 7 Dec 27, 2022
Graph Convolutional Networks in PyTorch

Graph Convolutional Networks in PyTorch PyTorch implementation of Graph Convolutional Networks (GCNs) for semi-supervised classification [1]. For a hi

Thomas Kipf 4.5k Dec 31, 2022
Outlier Exposure with Confidence Control for Out-of-Distribution Detection

OOD-detection-using-OECC This repository contains the essential code for the paper Outlier Exposure with Confidence Control for Out-of-Distribution De

Nazim Shaikh 64 Nov 02, 2022
Source code of the paper PatchGraph: In-hand tactile tracking with learned surface normals.

PatchGraph This repository contains the source code of the paper PatchGraph: In-hand tactile tracking with learned surface normals. Installation Creat

Paloma Sodhi 11 Dec 15, 2022
source code for 'Finding Valid Adjustments under Non-ignorability with Minimal DAG Knowledge' by A. Shah, K. Shanmugam, K. Ahuja

Source code for "Finding Valid Adjustments under Non-ignorability with Minimal DAG Knowledge" Reference: Abhin Shah, Karthikeyan Shanmugam, Kartik Ahu

Abhin Shah 1 Jun 03, 2022
Introduction to CPM

CPM CPM is an open-source program on large-scale pre-trained models, which is conducted by Beijing Academy of Artificial Intelligence and Tsinghua Uni

Tsinghua AI 136 Dec 23, 2022
Multi-Scale Aligned Distillation for Low-Resolution Detection (CVPR2021)

MSAD Multi-Scale Aligned Distillation for Low-Resolution Detection Lu Qi*, Jason Kuen*, Jiuxiang Gu, Zhe Lin, Yi Wang, Yukang Chen, Yanwei Li, Jiaya J

Jia Research Lab 115 Dec 23, 2022
TorchOk - The toolkit for fast Deep Learning experiments in Computer Vision

TorchOk - The toolkit for fast Deep Learning experiments in Computer Vision

52 Dec 23, 2022
Anchor Retouching via Model Interaction for Robust Object Detection in Aerial Images

Anchor Retouching via Model Interaction for Robust Object Detection in Aerial Images In this paper, we present an effective Dynamic Enhancement Anchor

13 Dec 09, 2022
Justmagic - Use a function as a method with this mystic script, like in Nim

justmagic Use a function as a method with this mystic script, like in Nim. Just

witer33 8 Oct 08, 2022
Official implementation for paper: A Latent Transformer for Disentangled Face Editing in Images and Videos.

A Latent Transformer for Disentangled Face Editing in Images and Videos Official implementation for paper: A Latent Transformer for Disentangled Face

InterDigital 108 Dec 09, 2022