Pytorch implementation of "Training a 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet"

Overview

Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet (arxiv)

This is a Pytorch implementation of our technical report.

Compare

Comparison between the proposed LV-ViT and other recent works based on transformers. Note that we only show models whose model sizes are under 100M.

Training Pipeline

Pipeline

Our codes are based on the pytorch-image-models by Ross Wightman.

LV-ViT Models

Model layer dim Image resolution Param Top 1 Download
LV-ViT-S 16 384 224 26.15M 83.3 link
LV-ViT-S 16 384 384 26.30M 84.4 link
LV-ViT-M 20 512 224 55.83M 84.0 link
LV-ViT-M 20 512 384 56.03M 85.4 link
LV-ViT-L 24 768 448 150.47M 86.2 link

Requirements

torch>=1.4.0 torchvision>=0.5.0 pyyaml timm==0.4.5

data prepare: ImageNet with the following folder structure, you can extract imagenet by this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Validation

Replace DATA_DIR with your imagenet validation set path and MODEL_DIR with the checkpoint path

CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint

Label data

We provide NFNet-F6 generated dense label map here. As NFNet-F6 are based on pure ImageNet data, no extra training data is involved.

Training

Coming soon

Reference

If you use this repo or find it useful, please consider citing:

@misc{jiang2021token,
      title={Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet}, 
      author={Zihang Jiang and Qibin Hou and Li Yuan and Daquan Zhou and Xiaojie Jin and Anran Wang and Jiashi Feng},
      year={2021},
      eprint={2104.10858},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Related projects

T2T-ViT, Re-labeling ImageNet.

Comments
  • error: download the pretrained model but couldn't be unzipped

    error: download the pretrained model but couldn't be unzipped

    tar -xvf lvvit_s-26M-384-84-4.pth.tar tar: This does not look like a tar archive tar: Skipping to next header tar: Exiting with failure status due to previous errors

    opened by Williamlizl 10
  • The accuracy of the validation set is 0,and the loss is always around 13

    The accuracy of the validation set is 0,and the loss is always around 13

    Hello! I use ILSVRC2012_img_train and ILSVRC2012_img_val, and use the provided label_top5_train_nfnet from Google Drive. I train lv-vit-s with batch_size 64 without apex for one epoch. Thanks for your advice.

    opened by yifanQi98 7
  • Pretrained weights for LV-ViT-T

    Pretrained weights for LV-ViT-T

    Hi,

    Thanks for sharing your work. Could you also provide the pre-trained weights for the LV-ViT-T model variant, the one that achieves 79.1% top1-acc. as mentioned in Table 1 of your paper?

    All the best, Marc

    opened by marc345 5
  • train error: AttributeError: 'tuple' object has no attribute 'log_softmax'

    train error: AttributeError: 'tuple' object has no attribute 'log_softmax'

    Hi, thanks for you great work. When I train script, some error occurs: AttributeError: 'tuple' object has no attribute 'log_softmax'

    with amp_autocast():   
                output = model(input)  
                loss = loss_fn(output, target)  # error occurs
    
    

    and loss function is train_loss_fn = LabelSmoothingCrossEntropy(smoothing=0.0).cuda()

    by the way: Could you please tell me why we need to specify smoothing=0.0?

    opened by lxy5513 5
  • RuntimeError: CUDA error: device-side assert triggered

    RuntimeError: CUDA error: device-side assert triggered

    I am a green hand of DL. When I run the code of volo with tlt in a single or multi GPU, I get an error as follows: /pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:312: operator(): block: [0,0,0], thread: [25,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed. Traceback (most recent call last): File "main.py", line 949, in main() File "main.py", line 664, in main optimizers=optimizers) File "main.py", line 773, in train_one_epoch label_size=args.token_label_size) File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 90, in mixup_target y1 = get_labelmaps_with_coords(target, num_classes, on_value=on_value, off_value=off_value, device=device, label_size=label_size) File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 64, in get_labelmaps_with_coords num_classes=num_classes,device=device) File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 16, in get_featuremaps _label_topk[1][:, :, :].long(), RuntimeError: CUDA error: device-side assert triggered.

    I can't fix this problem right now.

    opened by JIAOJIAYUASD 4
  • Generating label for custom dataset

    Generating label for custom dataset

    Hello,

    Thank you for sharing your work. I am currently trying to generate token label to a custom dataset for model lvvit_s, but I keep getting the loss close to 7 and the Accuracy 0 (not pre-trained and using 1 GPU in Google Colab). I also tried using the pre-trained model with --transfer but got 0 in both Loss and Acc . What option should I use for a custom dataset? image

    opened by AleMaiaF 2
  • generate_label.py unable to find model lvvit_s

    generate_label.py unable to find model lvvit_s

    Hi,

    When I tried to run the label generation script for the model lvvit_s it returned an error "RuntimeError: Unknown model".

    Solution: It worked when I added the line "import tlt.models" in the file generate_label.py.

    opened by AleMaiaF 2
  • Can Token labeling reach higher than annotator model?

    Can Token labeling reach higher than annotator model?

    Greetings,

    Thank you for this incredible research.

    I would like to know if it is possible to use Token Labeling to achieve scores higher than that of the annotator model, I believe this was the case with VOLO D5 model where it achieved higher score than NFNet, model used for annotation.

    opened by ErenBalatkan 1
  • label_map does not do the same augmentation (random crop) as the input image

    label_map does not do the same augmentation (random crop) as the input image

    Hi Thanks so much for the nice work! I am curious if you could share the insight on processing of the label_map. If I understand it correctly, after we load image and the corresponding, we shall do the same cropping/ flip/ resize, but in https://github.com/zihangJiang/TokenLabeling/blob/aa438eff9b9fc2daa8c8b4cc6bfaa6e3721f995e/tlt/data/label_transforms_factory.py#L58-L73 Seems only image was cropped, but the label map does not do the same cropping, which make the label map not match with the image?

    Shall we do

            return torchvision_F.resized_crop(
                    img, i, j, h, w, self.size, interpolation
            ), torchvision_F.resized_crop(
                    label_map, i / ratio, j / ratio, h / ratio, w / ratio, self.size, interpolation
            )
    

    Thanks

    opened by haooooooqi 1
  • Python3.6, ok; Python3.8, error

    Python3.6, ok; Python3.8, error

    Test: [ 0/1] Time: 11.293 (11.293) Loss: 0.7043 (0.7043) [email protected]: 42.1875 (42.1875) [email protected]: 100.0000 (100.0000) Test: [ 1/1] Time: 0.108 (5.701) Loss: 0.5847 (0.6689) [email protected]: 89.8148 (56.3187) [email protected]: 100.0000 (100.0000) free(): invalid pointer free(): invalid pointer Traceback (most recent call last): File "/opt/conda/lib/python3.8/runpy.py", line 194, in _run_module_as_main return _run_code(code, main_globals, None, File "/opt/conda/lib/python3.8/runpy.py", line 87, in _run_code exec(code, run_globals) File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 303, in <module> main() File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 294, in main raise subprocess.CalledProcessError(returncode=process.returncode, subprocess.CalledProcessError: Command '['/opt/conda/bin/python3.8', '-u', 'main.py', '--local_rank=1', './dataset/c/c', '--model', 'lvvit_s', '-b', '128', '--apex-amp', '--img-size', '224', '--drop-path', '0.1', '--token-label', '--token-label-size', '14', '--dense-weight', '0.0', '--num-classes', '2', '--finetune', './pretrained/lvvit_s-26M-384-84-4.pth.tar']' died with <Signals.SIGABRT: 6>. [email protected]:/puxin_libochao/TokenLabeling# CUDA_VISIBLE_DEVICES=0,1 bash ./distributed_train.sh 2 ./dataset/c/c --model lvvit_s -b 128 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-size 14 --dense-weight 0.0 --num-classes 2 --finetune ./pretrained/lvvit_s-26M-384-84-4.pth.tar

    opened by Williamlizl 1
  • A Bag of Training Techniques for ViT

    A Bag of Training Techniques for ViT

    Hi, thanks for your wonderful work. I have a question that whether training techniques mentioned in the LV-Vit can be used in other downstream task like object detection? In your paper, I see that many of this techniques are used in ImageNet. Thanks!

    opened by qdd1234 1
  • how to apply token labeling to CNN ?

    how to apply token labeling to CNN ?

    Hello ~ I'm interested in your token labeling technique, So I want to apply this technique in CNN based model because ViT is very heavy to train.

    can I get the your code with CNN token labeling? if you're not give me some detail for implementing

    thank you.

    opened by HoJ00n2 0
  • Model settings for Cifar10

    Model settings for Cifar10

    I am interested if there is any LV-ViT- model setup you have tested for Cifar10. I would like to know the proper setup of all blocks in none pretrained weights settings.

    opened by Aminullah6264 0
Owner
蒋子航
Now a Ph.D. student supervised by Prof. Feng Jiashi in ECE, NUS.
蒋子航
A TensorFlow implementation of FCN-8s

FCN-8s implementation in TensorFlow Contents Overview Examples and demo video Dependencies How to use it Download pre-trained VGG-16 Overview This is

Pierluigi Ferrari 50 Aug 08, 2022
TensorFlow-based implementation of "Pyramid Scene Parsing Network".

PSPNet_tensorflow Important Code is fine for inference. However, the training code is just for reference and might be only used for fine-tuning. If yo

HsuanKung Yang 323 Dec 20, 2022
SASM - simple crossplatform IDE for NASM, MASM, GAS and FASM assembly languages

SASM (SimpleASM) - простая кроссплатформенная среда разработки для языков ассемблера NASM, MASM, GAS, FASM с подсветкой синтаксиса и отладчиком. В SA

Dmitriy Manushin 5.6k Jan 06, 2023
Script that attempts to force M1 macs into RGB mode when used with monitors that are defaulting to YPbPr.

fix_m1_rgb Script that attempts to force M1 macs into RGB mode when used with monitors that are defaulting to YPbPr. No warranty provided for using th

Kevin Gao 116 Jan 01, 2023
Dataloader tools for language modelling

Installation: pip install lm_dataloader Design Philosophy A library to unify lm dataloading at large scale Simple interface, any tokenizer can be inte

5 Mar 25, 2022
Gesture-Volume-Control - This Python program can adjust the system's volume by using hand gestures

Gesture-Volume-Control This Python program can adjust the system's volume by usi

VatsalAryanBhatanagar 1 Dec 30, 2021
Supplemental Code for "ImpressionNet :A Multi view Approach to Predict Socio Facial Impressions"

Supplemental Code for "ImpressionNet :A Multi view Approach to Predict Socio Facial Impressions" Environment requirement This code is based on Python

Rohan Kumar Gupta 1 Dec 19, 2021
Video Swin Transformer - PyTorch

Video-Swin-Transformer-Pytorch This repo is a simple usage of the official implementation "Video Swin Transformer". Introduction Video Swin Transforme

Haofan Wang 116 Dec 20, 2022
Tensorflow implementation of our method: "Triangle Graph Interest Network for Click-through Rate Prediction".

TGIN Tensorflow implementation of our method: "Triangle Graph Interest Network for Click-through Rate Prediction". Files in the folder dataset/ electr

Alibaba 21 Dec 21, 2022
RaftMLP: How Much Can Be Done Without Attention and with Less Spatial Locality?

RaftMLP RaftMLP: How Much Can Be Done Without Attention and with Less Spatial Locality? By Yuki Tatsunami and Masato Taki (Rikkyo University) [arxiv]

Okojo 20 Aug 31, 2022
“Data Augmentation for Cross-Domain Named Entity Recognition” (EMNLP 2021)

Data Augmentation for Cross-Domain Named Entity Recognition Authors: Shuguang Chen, Gustavo Aguilar, Leonardo Neves and Thamar Solorio This repository

<a href=[email protected]"> 18 Sep 10, 2022
Implementation of paper: "Image Super-Resolution Using Dense Skip Connections" in PyTorch

SRDenseNet-pytorch Implementation of paper: "Image Super-Resolution Using Dense Skip Connections" in PyTorch (http://openaccess.thecvf.com/content_ICC

wxy 114 Nov 26, 2022
Compositional and Parameter-Efficient Representations for Large Knowledge Graphs

NodePiece - Compositional and Parameter-Efficient Representations for Large Knowledge Graphs NodePiece is a "tokenizer" for reducing entity vocabulary

Michael Galkin 107 Jan 04, 2023
IEEE Winter Conference on Applications of Computer Vision 2022 Accepted

SSKT(Accepted WACV2022) Concept map Dataset Image dataset CIFAR10 (torchvision) CIFAR100 (torchvision) STL10 (torchvision) Pascal VOC (torchvision) Im

1 Nov 17, 2022
Pyramid addon for OpenAPI3 validation of requests and responses.

Validate Pyramid views against an OpenAPI 3.0 document Peace of Mind The reason this package exists is to give you peace of mind when providing a REST

Pylons Project 79 Dec 30, 2022
Deep Learning Algorithms for Hedging with Frictions

Deep Learning Algorithms for Hedging with Frictions This repository contains the Forward-Backward Stochastic Differential Equation (FBSDE) solver and

Xiaofei Shi 3 Dec 22, 2022
BlueFog Tutorials

BlueFog Tutorials Welcome to the BlueFog tutorials! In this repository, we've put together a collection of awesome Jupyter notebooks. These notebooks

4 Oct 27, 2021
Only valid pull requests will be allowed. Use python only and readme changes will not be accepted.

❌ This repo is excluded from hacktoberfest This repo is for python beginners and contains lot of beginner python projects for practice. You can also s

Prajjwal Pathak 50 Dec 28, 2022
Codes of paper "Unseen Object Amodal Instance Segmentation via Hierarchical Occlusion Modeling"

Unseen Object Amodal Instance Segmentation (UOAIS) Seunghyeok Back, Joosoon Lee, Taewon Kim, Sangjun Noh, Raeyoung Kang, Seongho Bak, Kyoobin Lee This

GIST-AILAB 92 Dec 13, 2022
GitHub repository for "Improving Video Generation for Multi-functional Applications"

Improving Video Generation for Multi-functional Applications GitHub repository for "Improving Video Generation for Multi-functional Applications" Pape

Bernhard Kratzwald 328 Dec 07, 2022