PyTorch implementation of the ICASSP 2022 paper "Attention Probe: Vision Transformer Distillation in the Wild"


Attention Probe: Vision Transformer Distillation in the Wild

License: MIT

Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang
In ICASSP 2022

This code is the PyTorch implementation of the ICASSP 2022 paper "Attention Probe: Vision Transformer Distillation in the Wild".

Overview

  • We propose the concept of the Attention Probe, a special section of the attention map, which lets us use a large amount of unlabeled data in the wild for data-free distillation of vision transformers. Instead of generating images from the teacher network under a series of priors, we identify the images most relevant to the given pre-trained network and task in a large unlabeled dataset (e.g., Flickr) and use them for knowledge distillation.
  • We propose a simple yet efficient distillation algorithm based on the Attention Probe, called probe distillation, which distills the student model using intermediate features of the teacher model.

Prerequisite

We use PyTorch 1.7.1 and CUDA 11.0. You can install them with:

pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

The code should also work with other PyTorch and CUDA versions.
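To confirm which versions actually ended up installed, a small check like the one below can help. It is a minimal sketch using only the standard library's importlib.metadata; the EXPECTED table simply mirrors the pinned install command and is not part of the repository.

```python
from importlib import metadata
from typing import Dict

# Versions pinned in the install command; other versions may also work.
EXPECTED = {"torch": "1.7.1", "torchvision": "0.8.2", "torchaudio": "0.7.2"}


def check_versions(expected: Dict[str, str]) -> Dict[str, str]:
    """Compare installed package versions with expected ones.

    Returns a status per package: "ok", "differs", or "missing".
    """
    status = {}
    for name, wanted in expected.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            status[name] = "missing"
            continue
        # Ignore a local build tag such as "+cu110" when comparing.
        status[name] = "ok" if found.split("+")[0] == wanted else "differs"
    return status


if __name__ == "__main__":
    for pkg, state in check_versions(EXPECTED).items():
        print(f"{pkg}: {state}")
```

If PyTorch is importable, torch.version.cuda reports the CUDA version it was built with, and torch.cuda.is_available() tells you whether a GPU is visible.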

Usage

Data Preparation

First, convert the CIFAR-10/100 and Tiny-ImageNet datasets to the ImageNet folder layout. For CIFAR-10, run:

python process_cifar10.py

For CIFAR-100, run:

python process_cifar100.py
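For reference, the core of such a conversion is turning each CIFAR row into an image file inside a per-class folder. The sketch below is a hypothetical stand-in, not process_cifar10.py or process_cifar100.py: it assumes the python version of CIFAR, where each row holds 3072 uint8 values (1024 red, then 1024 green, then 1024 blue, each plane row-major), and it writes binary PPM files so that only the standard library is needed. The helper names and the file naming scheme are illustrative.

```python
import os
from typing import Iterable, Sequence, Tuple

SIDE = 32  # CIFAR-10/100 images are 32x32 RGB
PIXELS = SIDE * SIDE


def cifar_row_to_ppm(row: Sequence[int]) -> bytes:
    """Encode one CIFAR row (R plane, G plane, B plane) as a binary PPM (P6) image."""
    if len(row) != 3 * PIXELS:
        raise ValueError(f"expected {3 * PIXELS} values, got {len(row)}")
    pixels = bytearray()
    for i in range(PIXELS):
        # Interleave the three channel planes into R, G, B triples.
        pixels.extend((int(row[i]), int(row[PIXELS + i]), int(row[2 * PIXELS + i])))
    header = f"P6\n{SIDE} {SIDE}\n255\n".encode("ascii")
    return header + bytes(pixels)


def write_imagenet_style(samples: Iterable[Tuple[Sequence[int], int]],
                         class_names: Sequence[str], root: str, split: str) -> int:
    """Write (row, label) samples to root/split/<class_name>/<index>.ppm; return the count."""
    count = 0
    for index, (row, label) in enumerate(samples):
        class_dir = os.path.join(root, split, class_names[label])
        os.makedirs(class_dir, exist_ok=True)
        with open(os.path.join(class_dir, f"{index:05d}.ppm"), "wb") as f:
            f.write(cifar_row_to_ppm(row))
        count += 1
    return count
```

Each CIFAR-10 data_batch_* file unpickles (pickle.load(f, encoding="bytes"), which needs NumPy installed) to a dict whose b"data" and b"labels" entries can be zipped into samples; CIFAR-100 uses b"fine_labels" instead. The .ppm extension is among those torchvision's ImageFolder accepts.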

For Tiny-ImageNet, run:

python process_tinyimagenet.py
python process_move_file.py
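In the standard tiny-imagenet-200 release, validation images sit in a single val/images/ folder and their labels are listed in val/val_annotations.txt (tab-separated: file name, WordNet ID, then box coordinates). A minimal sketch of regrouping them into per-class folders, assuming that layout, is shown below; it is not the repository's process_move_file.py, whose exact behaviour is not described here.

```python
import os
import shutil
from typing import Dict


def read_val_annotations(path: str) -> Dict[str, str]:
    """Map each validation file name to its WordNet ID (second tab-separated field)."""
    mapping = {}
    with open(path) as f:
        for line in f:
            fields = line.rstrip("\n").split("\t")
            if len(fields) >= 2:
                mapping[fields[0]] = fields[1]
    return mapping


def reorganize_val(val_dir: str) -> int:
    """Move val/images/<file> to val/<wnid>/<file>; return how many files were moved."""
    mapping = read_val_annotations(os.path.join(val_dir, "val_annotations.txt"))
    images_dir = os.path.join(val_dir, "images")
    moved = 0
    for filename, wnid in mapping.items():
        src = os.path.join(images_dir, filename)
        if not os.path.isfile(src):
            continue  # already moved, or listed but absent
        dst_dir = os.path.join(val_dir, wnid)
        os.makedirs(dst_dir, exist_ok=True)
        shutil.move(src, os.path.join(dst_dir, filename))
        moved += 1
    # Drop the now-empty images/ folder so it is not mistaken for a class.
    if os.path.isdir(images_dir) and not os.listdir(images_dir):
        os.rmdir(images_dir)
    return moved
```

Running it twice is harmless: files that were already moved are skipped.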

The dataset directory should have the following structure:

dir/
  train/
    ...
  val/
    n01440764/
      ILSVRC2012_val_00000293.JPEG
      ...
    ...
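Before training, it can be worth checking that a converted dataset matches this layout. The sketch below assumes class folders directly under train/ and val/ (images may be nested deeper, as in Tiny-ImageNet's train/<wnid>/images/); the extension list and helper names are assumptions for illustration.

```python
import os
from typing import Dict

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp")


def count_images(split_dir: str) -> Dict[str, int]:
    """Count image files under each class folder of one split, including subfolders."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # e.g. an annotations file next to the class folders
        counts[class_name] = sum(
            1
            for _, _, files in os.walk(class_dir)
            for name in files
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts


def check_layout(root: str) -> Dict[str, Dict[str, int]]:
    """Verify root/train and root/val exist, share class folders, and none is empty."""
    report = {}
    for split in ("train", "val"):
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        counts = count_images(split_dir)
        empty = [name for name, n in counts.items() if n == 0]
        if not counts or empty:
            raise ValueError(f"{split_dir}: no class folders, or empty classes {empty}")
        report[split] = counts
    if set(report["train"]) != set(report["val"]):
        raise ValueError("train/ and val/ contain different class folders")
    return report
```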

Train a normal teacher network

In this step, you train standard teacher transformer models, which are later used to select valuable data from the wild. We train the teacher models with the timm PyTorch library:

timm

Our pretrained teacher models (CIFAR-10, CIFAR-100, ImageNet, Tiny-ImageNet, MNIST) can be downloaded here:

Pretrained teacher models

Select valuable data from the wild

Next, use the Attention Probe method to select valuable data from the wild dataset.
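The general shape of teacher-guided selection is: score every unlabeled image with the teacher, then keep the top k (the training commands use --num_select 650000). The Attention Probe score itself is defined in the paper and is not reproduced here; as a clearly labeled stand-in, the sketch below scores each image by the teacher's maximum softmax probability and keeps the teacher's prediction as a pseudo-label. The function names are hypothetical.

```python
from typing import Iterable, Tuple

import torch
import torch.nn.functional as F


@torch.no_grad()
def score_pool(teacher: torch.nn.Module,
               batches: Iterable[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Score every unlabeled image by the teacher's maximum softmax probability.

    Stand-in criterion only: this is not the Attention Probe score.
    """
    teacher.eval()
    scores, labels = [], []
    for images in batches:
        probs = F.softmax(teacher(images), dim=1)
        confidence, predicted = probs.max(dim=1)
        scores.append(confidence)
        labels.append(predicted)
    return torch.cat(scores), torch.cat(labels)


def select_top_k(scores: torch.Tensor, labels: torch.Tensor,
                 k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Keep the k highest-scoring images and their teacher pseudo-labels."""
    k = min(k, scores.numel())
    indices = torch.topk(scores, k).indices
    return indices, labels[indices]
```

Scoring batch by batch keeps memory bounded for a large wild pool; only one score and one label per image are kept.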

To select valuable data for CIFAR-10, run:

bash training.sh
(CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar10 --data_cifar $root_cifar10 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar10 --selected_file $selected_cifar10 --output_dir $output_student_cifar10 --nb_classes 10 --lr_S 7.5e-4 --attnprobe_sel --attnprobe_dist )

For CIFAR-100, run:

bash training.sh
(CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar100 --data_cifar $root_cifar100 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar100 --selected_file $selected_cifar100 --output_dir $output_student_cifar100 --nb_classes 100 --lr_S 8.5e-4 --attnprobe_sel --attnprobe_dist)

For Tiny-ImageNet, run:

bash training_tinyimagenet.sh

For ImageNet, run:

bash training_imagenet.sh

Once the files class_weights.pth, pred_out.pth, value_blk3.pth, value_blk7.pth, and value_out.pth appear in the '/selected/cifar10/' or '/selected/cifar100/' directory, the selected data has been obtained.
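A small helper for checking whether selection has finished could look like the following. It only tests for the five file names listed in the previous paragraph (assumed to be exact) in a directory you pass in; which directory that is presumably depends on the --selected_file argument.

```python
import os
import sys
from typing import List

# Files this README lists as the output of the selection step.
EXPECTED_FILES = ("class_weights.pth", "pred_out.pth", "value_blk3.pth",
                  "value_blk7.pth", "value_out.pth")


def missing_selection_files(selected_dir: str) -> List[str]:
    """Return the expected selection outputs that are absent from selected_dir."""
    return [name for name in EXPECTED_FILES
            if not os.path.isfile(os.path.join(selected_dir, name))]


if __name__ == "__main__":
    target = sys.argv[1] if len(sys.argv) > 1 else "."
    missing = missing_selection_files(target)
    print("selection complete" if not missing else f"missing: {', '.join(missing)}")
```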

Probe Knowledge Distillation for Student networks

Then, distill the student model on the selected data, using intermediate features of the teacher model.
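The exact probe-distillation objective is not spelled out in this README. As an assumption for illustration, the sketch below combines a standard soft-label KD term (KL divergence on temperature-softened logits) with an MSE term matching paired student and teacher intermediate features. The file names value_blk3.pth and value_blk7.pth hint that features from blocks 3 and 7 are involved, but that is a guess. The sketch also assumes paired features already have the same shape; otherwise a learned projection would be needed. The function name, temperature, and weight are hypothetical.

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def probe_distillation_loss(student_logits: torch.Tensor,
                            teacher_logits: torch.Tensor,
                            student_feats: Sequence[torch.Tensor],
                            teacher_feats: Sequence[torch.Tensor],
                            temperature: float = 4.0,
                            feat_weight: float = 1.0) -> torch.Tensor:
    """KD loss on softened logits plus MSE between paired intermediate features."""
    if len(student_feats) != len(teacher_feats):
        raise ValueError("student and teacher feature lists must have equal length")
    t = temperature
    # Soft-label KD term (Hinton et al.), scaled by T^2 to keep gradient magnitudes comparable.
    kd = F.kl_div(F.log_softmax(student_logits / t, dim=1),
                  F.softmax(teacher_logits / t, dim=1),
                  reduction="batchmean") * (t * t)
    # Feature term: teacher features are fixed targets, so they are detached.
    feat = sum(F.mse_loss(s, tf.detach()) for s, tf in zip(student_feats, teacher_feats))
    return kd + feat_weight * feat
```

With a timm model, intermediate outputs could be collected with register_forward_hook on entries of model.blocks; which blocks to use is left open here.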

bash training.sh
(CIFAR 10 run: CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar10 --data_cifar $root_cifar10 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar10 --selected_file $selected_cifar10 --output_dir $output_student_cifar10 --nb_classes 10 --lr_S 7.5e-4 --attnprobe_sel --attnprobe_dist)

(CIFAR 100 run: CUDA_VISIBLE_DEVICES=0,1,2,3 python DFND_DeiT-train.py --dataset cifar100 --data_cifar $root_cifar100 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar100 --selected_file $selected_cifar100 --output_dir $output_student_cifar100 --nb_classes 100 --lr_S 8.5e-4 --attnprobe_sel --attnprobe_dist)

For Tiny-ImageNet, run:

bash training_tinyimagenet.sh

For ImageNet, run:

bash training_imagenet.sh

You will find the student transformer model in the '/output/cifar10/student/' or '/output/cifar100/student/' directory.

Our distilled student models (CIFAR-10, CIFAR-100, ImageNet, Tiny-ImageNet, MNIST) can be downloaded here: Distilled student models
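To evaluate a distilled student on an ImageNet-style val/ folder, a plain top-1 accuracy loop is enough. The sketch below assumes the model returns class logits and the loader yields (images, labels) batches; how the student checkpoint is stored and loaded is not described here and is left out.

```python
from typing import Iterable, Tuple

import torch


@torch.no_grad()
def top1_accuracy(model: torch.nn.Module,
                  loader: Iterable[Tuple[torch.Tensor, torch.Tensor]]) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        predictions = model(images).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.numel()
    return correct / total if total else 0.0
```

A typical loader would be torch.utils.data.DataLoader(torchvision.datasets.ImageFolder("dir/val", transform=...), batch_size=...), with the transform matching the student's input size and normalization.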

Results

Citation

@inproceedings{wang2022attention,
  title={Attention Probe: Vision Transformer Distillation in the Wild},
  author={Jiahao Wang and Mingdeng Cao and Shuwei Shi and Baoyuan Wu and Yujiu Yang},
  booktitle={International Conference on Acoustics, Speech and Signal Processing},
  year={2022},
  url={https://2022.ieeeicassp.org/}
}

Acknowledgement

Owner: IIGROUP (The Intelligent Interaction Group at Tsinghua University)