CoaT: Co-Scale Conv-Attentional Image Transformers

Related tags

Deep LearningCoaT
Overview

CoaT: Co-Scale Conv-Attentional Image Transformers

Introduction

This repository contains the official code and pretrained models for CoaT: Co-Scale Conv-Attentional Image Transformers. It introduces (1) a co-scale mechanism to realize fine-to-coarse, coarse-to-fine and cross-scale attention modeling and (2) an efficient conv-attention module to realize relative position encoding in the factorized attention.

Model Accuracy

For more details, please refer to CoaT: Co-Scale Conv-Attentional Image Transformers by Weijian Xu*, Yifan Xu*, Tyler Chang, and Zhuowen Tu.

Changelog

04/23/2021: Pre-trained checkpoint for CoaT-Lite Mini is released.
04/22/2021: Code and pre-trained checkpoint for CoaT-Lite Tiny are released.

Usage

Environment Preparation

  1. Set up a new conda environment and activate it.

    # Create an environment with Python 3.8.
    conda create -n coat python==3.8
    conda activate coat
  2. Install required packages.

    # Install PyTorch 1.7.1 w/ CUDA 11.0.
    pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
    
    # Install timm 0.3.2.
    pip install timm==0.3.2
    
    # Install einops.
    pip install einops

Code and Dataset Preparation

  1. Clone the repo.

    git clone https://github.com/mlpc-ucsd/CoaT
    cd CoaT
  2. Download ImageNet dataset (ILSVRC 2012) and extract.

    # Create dataset folder.
    mkdir -p ./data/ImageNet
    
    # Download the dataset (not shown here) and copy the files (assume the download path is in $DATASET_PATH).
    cp $DATASET_PATH/ILSVRC2012_img_train.tar $DATASET_PATH/ILSVRC2012_img_val.tar $DATASET_PATH/ILSVRC2012_devkit_t12.tar.gz ./data/ImageNet
    
    # Extract the dataset.
    python -c "from torchvision.datasets import ImageNet; ImageNet('./data/ImageNet', split='train')"
    python -c "from torchvision.datasets import ImageNet; ImageNet('./data/ImageNet', split='val')"
    # After the extraction, you should observe `train` and `val` folders under ./data/ImageNet.

Evaluate Pre-trained Checkpoint

We provide the CoaT checkpoints pre-trained on the ImageNet dataset.

Name [email protected] [email protected] #Params SHA-256 (first 8 chars) URL
CoaT-Lite Tiny 77.5 93.8 5.7M e88e96b0 model, log
CoaT-Lite Mini 79.1 94.5 11M 6b4a8ae5 model, log

The following commands provide an example (CoaT-Lite Tiny) to evaluate the pre-trained checkpoint.

# Download the pretrained checkpoint.
mkdir -p ./output/pretrained
wget http://vcl.ucsd.edu/coat/pretrained/coat_lite_tiny_e88e96b0.pth -P ./output/pretrained
sha256sum ./output/pretrained/coat_lite_tiny_e88e96b0.pth  # Make sure it matches the SHA-256 hash (first 8 characters) in the table.

# Evaluate.
# Usage: bash ./scripts/eval.sh [model name] [output folder] [checkpoint path]
bash ./scripts/eval.sh coat_lite_tiny coat_lite_tiny_pretrained ./output/pretrained/coat_lite_tiny_e88e96b0.pth
# It should output results similar to "[email protected] 77.504 [email protected] 93.814" at very last.

Train

The following commands provide an example (CoaT-Lite Tiny, 8-GPU) to train the CoaT model.

# Usage: bash ./scripts/train.sh [model name] [output folder]
bash ./scripts/train.sh coat_lite_tiny coat_lite_tiny

Evaluate

The following commands provide an example (CoaT-Lite Tiny) to evaluate the checkpoint after training.

# Usage: bash ./scripts/eval.sh [model name] [output folder] [checkpoint path]
bash ./scripts/eval.sh coat_lite_tiny coat_lite_tiny_eval ./output/coat_lite_tiny/checkpoints/checkpoint0299.pth

Citation

@misc{xu2021coscale,
      title={Co-Scale Conv-Attentional Image Transformers}, 
      author={Weijian Xu and Yifan Xu and Tyler Chang and Zhuowen Tu},
      year={2021},
      eprint={2104.06399},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

License

This repository is released under the Apache License 2.0. License can be found in LICENSE file.

Acknowledgment

Thanks to DeiT and pytorch-image-models for a clear and data-efficient implementation of ViT. Thanks to lucidrains' implementation of Lambda Networks and CPVT.

Owner
mlpc-ucsd
mlpc-ucsd
Differentiable Neural Computers, Sparse Access Memory and Sparse Differentiable Neural Computers, for Pytorch

Differentiable Neural Computers and family, for Pytorch Includes: Differentiable Neural Computers (DNC) Sparse Access Memory (SAM) Sparse Differentiab

ixaxaar 302 Dec 14, 2022
Official PyTorch Implementation of paper "Deep 3D Mask Volume for View Synthesis of Dynamic Scenes", ICCV 2021.

Deep 3D Mask Volume for View Synthesis of Dynamic Scenes Official PyTorch Implementation of paper "Deep 3D Mask Volume for View Synthesis of Dynamic S

Ken Lin 17 Oct 12, 2022
CIFAR-10 Photo Classification

Image-Classification CIFAR-10 Photo Classification CIFAR-10_Dataset_Classfication CIFAR-10 Photo Classification Dataset CIFAR is an acronym that stand

ADITYA SHAH 1 Jan 05, 2022
An efficient framework for reinforcement learning.

rl: An efficient framework for reinforcement learning Requirements Introduction PPO Test Requirements name version Python =3.7 numpy =1.19 torch =1

16 Nov 30, 2022
This repository provides the official code for GeNER (an automated dataset Generation framework for NER).

GeNER This repository provides the official code for GeNER (an automated dataset Generation framework for NER). Overview of GeNER GeNER allows you to

DMIS Laboratory - Korea University 50 Nov 30, 2022
Code for the Higgs Boson Machine Learning Challenge organised by CERN & EPFL

A method to solve the Higgs boson challenge using Least Squares - Novae This project is the Project 1 of EPFL CS-433 Machine Learning. The project is

Giacomo Orsi 1 Nov 09, 2021
Object detection GUI based on PaddleDetection

PP-Tracking GUI界面测试版 本项目是基于飞桨开源的实时跟踪系统PP-Tracking开发的可视化界面 在PaddlePaddle中加入pyqt进行GUI页面研发,可使得整个训练过程可视化,并通过GUI界面进行调参,模型预测,视频输出等,通过多种类型的识别,简化整体预测流程。 GUI界面

杨毓栋 68 Jan 02, 2023
Python3 / PyTorch implementation of the following paper: Fine-grained Semantics-aware Representation Enhancement for Self-supervisedMonocular Depth Estimation. ICCV 2021 (oral)

FSRE-Depth This is a Python3 / PyTorch implementation of FSRE-Depth, as described in the following paper: Fine-grained Semantics-aware Representation

77 Dec 28, 2022
The Dual Memory is build from a simple CNN for the deep memory and Linear Regression fro the fast Memory

Simple-DMA a simple Dual Memory Architecture for classifications. based on the paper Dual-Memory Deep Learning Architectures for Lifelong Learning of

1 Jan 27, 2022
(NeurIPS 2021) Pytorch implementation of paper "Re-ranking for image retrieval and transductive few-shot classification"

SSR (NeurIPS 2021) Pytorch implementation of paper "Re-ranking for image retrieval and transductivefew-shot classification" [Paper] [Project webpage]

xshen 29 Dec 06, 2022
[2021 MultiMedia] CONQUER: Contextual Query-aware Ranking for Video Corpus Moment Retrieval

CONQUER: Contexutal Query-aware Ranking for Video Corpus Moment Retreival PyTorch implementation of CONQUER: Contexutal Query-aware Ranking for Video

Hou zhijian 23 Dec 26, 2022
Human Activity Recognition example using TensorFlow on smartphone sensors dataset and an LSTM RNN. Classifying the type of movement amongst six activity categories - Guillaume Chevalier

LSTMs for Human Activity Recognition Human Activity Recognition (HAR) using smartphones dataset and an LSTM RNN. Classifying the type of movement amon

Guillaume Chevalier 3.1k Dec 30, 2022
Unifying Global-Local Representations in Salient Object Detection with Transformer

GLSTR (Global-Local Saliency Transformer) This is the official implementation of paper "Unifying Global-Local Representations in Salient Object Detect

11 Aug 24, 2022
[NeurIPS 2021] Code for Unsupervised Learning of Compositional Energy Concepts

Unsupervised Learning of Compositional Energy Concepts This is the pytorch code for the paper Unsupervised Learning of Compositional Energy Concepts.

45 Nov 30, 2022
A pytorch-based real-time segmentation model for autonomous driving

CFPNet: Channel-Wise Feature Pyramid for Real-Time Semantic Segmentation This project contains the Pytorch implementation for the proposed CFPNet: pap

342 Dec 22, 2022
Implementation of Memformer, a Memory-augmented Transformer, in Pytorch

Memformer - Pytorch Implementation of Memformer, a Memory-augmented Transformer, in Pytorch. It includes memory slots, which are updated with attentio

Phil Wang 60 Nov 06, 2022
Machine Learning Time-Series Platform

cesium: Open-Source Platform for Time Series Inference Summary cesium is an open source library that allows users to: extract features from raw time s

632 Dec 26, 2022
Cross-Image Region Mining with Region Prototypical Network for Weakly Supervised Segmentation

Cross-Image Region Mining with Region Prototypical Network for Weakly Supervised Segmentation The code of: Cross-Image Region Mining with Region Proto

LiuWeide 16 Nov 26, 2022
Vision transformers (ViTs) have found only limited practical use in processing images

CXV Convolutional Xformers for Vision Vision transformers (ViTs) have found only limited practical use in processing images, in spite of their state-o

Cloudwalker 23 Sep 10, 2022
Transfer-Learn is an open-source and well-documented library for Transfer Learning.

Transfer-Learn is an open-source and well-documented library for Transfer Learning. It is based on pure PyTorch with high performance and friendly API. Our code is pythonic, and the design is consist

THUML @ Tsinghua University 2.2k Jan 03, 2023