Official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

Overview

CrossViT

This repository is the official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification. ArXiv

If you use the codes and models from this repo, please cite our work. Thanks!

@inproceedings{
    chen2021crossvit,
    title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
    author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
    booktitle={International Conference on Computer Vision (ICCV)},
    year={2021}
}

Installation

To install requirements:

pip install -r requirements.txt

With conda:

conda create -n crossvit python=3.8
conda activate crossvit
conda install pytorch=1.7.1 torchvision  cudatoolkit=11.0 -c pytorch -c nvidia
pip install -r requirements.txt

Data preparation

Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout for the torchvision datasets.ImageFolder, and the training and validation data is expected to be in the train/ folder and val folder respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class/2
      img4.jpeg

Pretrained models

We provide models trained on ImageNet1K. You can find models here. And you can load pretrained weights into models by add --pretrained flag.

Training

To train crossvit_9_dagger_224 on ImageNet on a single node with 8 gpus for 300 epochs run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Other model names can be found at models/crossvit.py.

Multinode training

Distributed training is available via Slurm and submitit:

To train a crossvit_9_dagger_224 model on ImageNet on 4 nodes with 8 gpus each for 300 epochs:

python run_with_submitit.py --nodes 4 --model crossvit_9_dagger_224 --data-path /path/to/imagenet --batch-size 128 --warmup-epochs 30

Or you can start process on each machine maunally. E.g. 2 nodes, each with 8 gpus.

Machine 0:

python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=0 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Machine 1:

python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=1 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Note that: some slurm configurations might need to be changed based on your cluster.

Evaluation

To evaluate a pretrained model on crossvit_9_dagger_224:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 128 --data-path /path/to/imagenet --eval --pretrained
You might also like...
This repository builds a basic vision transformer from scratch so that one beginner can understand the theory of vision transformer.

vision-transformer-from-scratch This repository includes several kinds of vision transformers from scratch so that one beginner can understand the the

Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal, multi-exposure and multi-focus image fusion.

U2Fusion Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal (VIS-IR, medical), multi

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Object Detection and Instance Segmentation.

Swin Transformer for Object Detection This repo contains the supported code and configuration files to reproduce object detection results of Swin Tran

Unofficial implementation of MUSIQ (Multi-Scale Image Quality Transformer)
Unofficial implementation of MUSIQ (Multi-Scale Image Quality Transformer)

MUSIQ: Multi-Scale Image Quality Transformer Unofficial pytorch implementation of the paper "MUSIQ: Multi-Scale Image Quality Transformer" (paper link

Official implementation of
Official implementation of "Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets" (CVPR2021)

Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets This is the official implementation of "Towards Good Pract

Official implementation of paper
Official implementation of paper "Query2Label: A Simple Transformer Way to Multi-Label Classification".

Introdunction This is the official implementation of the paper "Query2Label: A Simple Transformer Way to Multi-Label Classification". Abstract This pa

Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification
Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification

STAM - Pytorch Implementation of STAM (Space Time Attention Model), yet another pure and simple SOTA attention model that bests all previous models in

Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild
Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild

Attention Probe: Vision Transformer Distillation in the Wild Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang In ICASSP 2022 This code is

Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Memory Efficient Attention Pytorch Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(

Comments
  • Multilabel classification

    Multilabel classification

    @jjasghar @krook @chunfuchen thanks a lot for sharing the code , i have a problem statement which i want to train crossvit on please find the details below

    1. input is of image shape 256x128 having label vector as gt for multilabel classification
    2. can i use crossvit to train the model , what all modification has to be done with the code base ??

    THnaks for the support

    opened by abhigoku10 4
  • Honoring distributed flag + fixing reset_classifier

    Honoring distributed flag + fixing reset_classifier

    1. Honoring the args.distributed flag in calls to evaluate().
    2. A couple of changes to make the reset_classifier() method work:
    • Initializing the embed_dim instance variable in VisionTransformer.
    • Reinitializing the classification head for all branches.
    opened by abhrac 1
  • Parameter setting

    Parameter setting

    Hello, thank you for your excellent work, I would like to know how you set the parameters on the CIFAR10 dataset, mainly the size of the patch,Looking forward to your reply

    opened by happy20200 1
Owner
International Business Machines
International Business Machines
A-SDF: Learning Disentangled Signed Distance Functions for Articulated Shape Representation (ICCV 2021)

A-SDF: Learning Disentangled Signed Distance Functions for Articulated Shape Representation (ICCV 2021) This repository contains the official implemen

81 Dec 14, 2022
ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators

ELECTRA Introduction ELECTRA is a method for self-supervised language representation learning. It can be used to pre-train transformer networks using

Google Research 2.1k Dec 28, 2022
Evolving neural network parameters in JAX.

Evolving Neural Networks in JAX This repository holds code displaying techniques for applying evolutionary network training strategies in JAX. Each sc

Trevor Thackston 6 Feb 12, 2022
Data and Code for ACL 2021 Paper "Inter-GPS: Interpretable Geometry Problem Solving with Formal Language and Symbolic Reasoning"

Introduction Code and data for ACL 2021 Paper "Inter-GPS: Interpretable Geometry Problem Solving with Formal Language and Symbolic Reasoning". We cons

Pan Lu 81 Dec 27, 2022
ELSED: Enhanced Line SEgment Drawing

ELSED: Enhanced Line SEgment Drawing This repository contains the source code of ELSED: Enhanced Line SEgment Drawing the fastest line segment detecto

Iago Suárez 125 Dec 31, 2022
ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives

Status: Under development (expect bug fixes and huge updates) ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectiv

37 Dec 28, 2022
Official PyTorch implementation of "Meta-Learning with Task-Adaptive Loss Function for Few-Shot Learning" (ICCV2021 Oral)

MeTAL - Meta-Learning with Task-Adaptive Loss Function for Few-Shot Learning (ICCV2021 Oral) Sungyong Baik, Janghoon Choi, Heewon Kim, Dohee Cho, Jaes

Sungyong Baik 44 Dec 29, 2022
[CVPR 2021] Few-shot 3D Point Cloud Semantic Segmentation

Few-shot 3D Point Cloud Semantic Segmentation Created by Na Zhao from National University of Singapore Introduction This repository contains the PyTor

117 Dec 27, 2022
A curated list of references for MLOps

A curated list of references for MLOps

Larysa Visengeriyeva 9.3k Jan 07, 2023
CenterPoint 3D Object Detection and Tracking using center points in the bird-eye view.

CenterPoint 3D Object Detection and Tracking using center points in the bird-eye view. Center-based 3D Object Detection and Tracking, Tianwei Yin, Xin

Tianwei Yin 134 Dec 23, 2022
Rational Activation Functions - Replacing Padé Activation Units

Rational Activations - Learnable Rational Activation Functions First introduce as PAU in Padé Activation Units: End-to-end Learning of Activation Func

<a href=[email protected]"> 38 Nov 22, 2022
Learning to Disambiguate Strongly Interacting Hands via Probabilistic Per-Pixel Part Segmentation [3DV 2021 Oral]

Learning to Disambiguate Strongly Interacting Hands via Probabilistic Per-Pixel Part Segmentation [3DV 2021 Oral] Learning to Disambiguate Strongly In

Zicong Fan 40 Dec 22, 2022
Simple PyTorch implementations of Badnets on MNIST and CIFAR10.

Simple PyTorch implementations of Badnets on MNIST and CIFAR10.

Vera 75 Dec 13, 2022
A web porting for NVlabs' StyleGAN2, to facilitate exploring all kinds characteristic of StyleGAN networks

This project is a web porting for NVlabs' StyleGAN2, to facilitate exploring all kinds characteristic of StyleGAN networks. Thanks for NVlabs' excelle

K.L. 150 Dec 15, 2022
A Light CNN for Deep Face Representation with Noisy Labels

A Light CNN for Deep Face Representation with Noisy Labels Citation If you use our models, please cite the following paper: @article{wulight, title=

Alfred Xiang Wu 715 Nov 05, 2022
Code for paper "Vocabulary Learning via Optimal Transport for Neural Machine Translation"

**Codebase and data are uploaded in progress. ** VOLT(-py) is a vocabulary learning codebase that allows researchers and developers to automaticaly ge

416 Jan 09, 2023
Fair Recommendation in Two-Sided Platforms

Fair Recommendation in Two-Sided Platforms

gourabgggg 1 Nov 10, 2021
Neurolab is a simple and powerful Neural Network Library for Python

Neurolab Neurolab is a simple and powerful Neural Network Library for Python. Contains based neural networks, train algorithms and flexible framework

152 Dec 06, 2022
Contrastive Language-Image Pretraining

CLIP [Blog] [Paper] [Model Card] [Colab] CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pair

OpenAI 11.5k Jan 08, 2023
Baseline inference Algorithm for the STOIC2021 challenge.

STOIC2021 Baseline Algorithm This codebase contains an example submission for the STOIC2021 COVID-19 AI Challenge. As a baseline algorithm, it impleme

Luuk Boulogne 10 Aug 08, 2022