Official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

Overview

CrossViT

This repository is the official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification. ArXiv

If you use the codes and models from this repo, please cite our work. Thanks!

@inproceedings{
    chen2021crossvit,
    title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
    author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
    booktitle={International Conference on Computer Vision (ICCV)},
    year={2021}
}

Installation

To install requirements:

pip install -r requirements.txt

With conda:

conda create -n crossvit python=3.8
conda activate crossvit
conda install pytorch=1.7.1 torchvision  cudatoolkit=11.0 -c pytorch -c nvidia
pip install -r requirements.txt

Data preparation

Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout for the torchvision datasets.ImageFolder, and the training and validation data is expected to be in the train/ folder and val folder respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class/2
      img4.jpeg

Pretrained models

We provide models trained on ImageNet1K. You can find models here. And you can load pretrained weights into models by add --pretrained flag.

Training

To train crossvit_9_dagger_224 on ImageNet on a single node with 8 gpus for 300 epochs run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Other model names can be found at models/crossvit.py.

Multinode training

Distributed training is available via Slurm and submitit:

To train a crossvit_9_dagger_224 model on ImageNet on 4 nodes with 8 gpus each for 300 epochs:

python run_with_submitit.py --nodes 4 --model crossvit_9_dagger_224 --data-path /path/to/imagenet --batch-size 128 --warmup-epochs 30

Or you can start process on each machine maunally. E.g. 2 nodes, each with 8 gpus.

Machine 0:

python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=0 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Machine 1:

python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=1 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Note that: some slurm configurations might need to be changed based on your cluster.

Evaluation

To evaluate a pretrained model on crossvit_9_dagger_224:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 128 --data-path /path/to/imagenet --eval --pretrained
You might also like...
This repository builds a basic vision transformer from scratch so that one beginner can understand the theory of vision transformer.

vision-transformer-from-scratch This repository includes several kinds of vision transformers from scratch so that one beginner can understand the the

Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal, multi-exposure and multi-focus image fusion.

U2Fusion Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal (VIS-IR, medical), multi

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Object Detection and Instance Segmentation.

Swin Transformer for Object Detection This repo contains the supported code and configuration files to reproduce object detection results of Swin Tran

Unofficial implementation of MUSIQ (Multi-Scale Image Quality Transformer)
Unofficial implementation of MUSIQ (Multi-Scale Image Quality Transformer)

MUSIQ: Multi-Scale Image Quality Transformer Unofficial pytorch implementation of the paper "MUSIQ: Multi-Scale Image Quality Transformer" (paper link

Official implementation of
Official implementation of "Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets" (CVPR2021)

Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets This is the official implementation of "Towards Good Pract

Official implementation of paper
Official implementation of paper "Query2Label: A Simple Transformer Way to Multi-Label Classification".

Introdunction This is the official implementation of the paper "Query2Label: A Simple Transformer Way to Multi-Label Classification". Abstract This pa

Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification
Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification

STAM - Pytorch Implementation of STAM (Space Time Attention Model), yet another pure and simple SOTA attention model that bests all previous models in

Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild
Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild

Attention Probe: Vision Transformer Distillation in the Wild Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang In ICASSP 2022 This code is

Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Memory Efficient Attention Pytorch Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(

Comments
  • Multilabel classification

    Multilabel classification

    @jjasghar @krook @chunfuchen thanks a lot for sharing the code , i have a problem statement which i want to train crossvit on please find the details below

    1. input is of image shape 256x128 having label vector as gt for multilabel classification
    2. can i use crossvit to train the model , what all modification has to be done with the code base ??

    THnaks for the support

    opened by abhigoku10 4
  • Honoring distributed flag + fixing reset_classifier

    Honoring distributed flag + fixing reset_classifier

    1. Honoring the args.distributed flag in calls to evaluate().
    2. A couple of changes to make the reset_classifier() method work:
    • Initializing the embed_dim instance variable in VisionTransformer.
    • Reinitializing the classification head for all branches.
    opened by abhrac 1
  • Parameter setting

    Parameter setting

    Hello, thank you for your excellent work, I would like to know how you set the parameters on the CIFAR10 dataset, mainly the size of the patch,Looking forward to your reply

    opened by happy20200 1
Owner
International Business Machines
International Business Machines
An implementation of Deep Forest 2021.2.1.

Deep Forest (DF) 21 DF21 is an implementation of Deep Forest 2021.2.1. It is designed to have the following advantages: Powerful: Better accuracy than

LAMDA Group, Nanjing University 795 Jan 03, 2023
Compact Bidirectional Transformer for Image Captioning

Compact Bidirectional Transformer for Image Captioning Requirements Python 3.8 Pytorch 1.6 lmdb h5py tensorboardX Prepare Data Please use git clone --

YE Zhou 19 Dec 12, 2022
A Re-implementation of the paper "A Deep Learning Framework for Character Motion Synthesis and Editing"

What is This This is a simple re-implementation of the paper "A Deep Learning Framework for Character Motion Synthesis and Editing"(1). Only Sections

102 Dec 14, 2022
QTool: A Low-bit Quantization Toolbox for Deep Neural Networks in Computer Vision

This project provides abundant choices of quantization strategies (such as the quantization algorithms, training schedules and empirical tricks) for quantizing the deep neural networks into low-bit c

Monash Green AI Lab 51 Dec 10, 2022
A modular application for performing anomaly detection in networks

Deep-Learning-Models-for-Network-Annomaly-Detection The modular app consists for mainly three annomaly detection algorithms. The system supports model

Shivam Patel 1 Dec 09, 2021
Tensorflow-Project-Template - A best practice for tensorflow project template architecture.

Tensorflow Project Template A simple and well designed structure is essential for any Deep Learning project, so after a lot of practice and contributi

Mahmoud G. Salem 3.6k Dec 22, 2022
Learning an Adaptive Meta Model-Generator for Incrementally Updating Recommender Systems

Learning an Adaptive Meta Model-Generator for Incrementally Updating Recommender Systems This is our experimental code for RecSys 2021 paper "Learning

11 Jul 28, 2022
Adaptive Attention Span for Reinforcement Learning

Adaptive Transformers in RL Official implementation of Adaptive Transformers in RL In this work we replicate several results from Stabilizing Transfor

100 Nov 15, 2022
A research toolkit for particle swarm optimization in Python

PySwarms is an extensible research toolkit for particle swarm optimization (PSO) in Python. It is intended for swarm intelligence researchers, practit

Lj Miranda 1k Dec 30, 2022
Akshat Surolia 2 May 11, 2022
This project is based on RIFE and aims to make RIFE more practical for users by adding various features and design new models

CPM 项目描述 CPM(Chinese Pretrained Models)模型是北京智源人工智能研究院和清华大学发布的中文大规模预训练模型。官方发布了三种规模的模型,参数量分别为109M、334M、2.6B,用户需申请与通过审核,方可下载。 由于原项目需要考虑大模型的训练和使用,需要安装较为复杂

hzwer 190 Jan 08, 2023
Residual Pathway Priors for Soft Equivariance Constraints

Residual Pathway Priors for Soft Equivariance Constraints This repo contains the implementation and the experiments for the paper Residual Pathway Pri

Marc Finzi 13 Oct 12, 2022
Posterior temperature optimized Bayesian models for inverse problems in medical imaging

Posterior temperature optimized Bayesian models for inverse problems in medical imaging Max-Heinrich Laves*, Malte Tölle*, Alexander Schlaefer, Sandy

Artificial Intelligence in Cardiovascular Medicine (AICM) 6 Sep 19, 2022
Implementation for Learning to Track with Object Permanence

Learning to Track with Object Permanence A video-based MOT approach capable of tracking through full occlusions: Learning to Track with Object Permane

Toyota Research Institute - Machine Learning 91 Jan 03, 2023
CSAW-M: An Ordinal Classification Dataset for Benchmarking Mammographic Masking of Cancer

CSAW-M This repository contains code for CSAW-M: An Ordinal Classification Dataset for Benchmarking Mammographic Masking of Cancer. Source code for tr

Yue Liu 7 Oct 11, 2022
Adaptive Denoising Training (ADT) for Recommendation.

DenoisingRec Adaptive Denoising Training for Recommendation. This is the pytorch implementation of our paper at WSDM 2021: Denoising Implicit Feedback

Wenjie Wang 51 Dec 30, 2022
The code for replicating the experiments from the LFI in SSMs with Unknown Dynamics paper.

Likelihood-Free Inference in State-Space Models with Unknown Dynamics This package contains the codes required to run the experiments in the paper. Th

Alex Aushev 0 Dec 27, 2021
Codes for "CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation"

CSDI This is the github repository for the NeurIPS 2021 paper "CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation

106 Jan 04, 2023
The official PyTorch code implementation of "Personalized Trajectory Prediction via Distribution Discrimination" in ICCV 2021.

Personalized Trajectory Prediction via Distribution Discrimination (DisDis) The official PyTorch code implementation of "Personalized Trajectory Predi

25 Dec 20, 2022
PyTorch implementation of the paper Dynamic Token Normalization Improves Vision Transfromers.

Dynamic Token Normalization Improves Vision Transformers This is the PyTorch implementation of the paper Dynamic Token Normalization Improves Vision T

Wenqi Shao 20 Oct 09, 2022