Code implementation of the Data Efficient Stagewise Knowledge Distillation paper.

Data Efficient Stagewise Knowledge Distillation

[Figure: Stagewise Training Procedure]

Overview

This repository presents the code implementation for Stagewise Knowledge Distillation, a technique for improving knowledge transfer between a teacher model and a student model.

Requirements

  • Install the dependencies using conda with the environment.yml file
    conda env create -f environment.yml
    
  • Setup the stagewise-knowledge-distillation package itself
    pip install -e .
    
  • Apart from the above-mentioned dependencies, it is recommended to have an NVIDIA GPU (CUDA-compatible) with at least 8 GB of video memory (most of the experiments also work with 6 GB). However, the code runs on CPU-only machines as well.
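
After installation, you can run a quick sanity check like the one below (a minimal sketch, assuming PyTorch is among the installed dependencies) to confirm that the environment and GPU are visible:

# Minimal environment sanity check (assumes PyTorch is installed).
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))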

Image Classification

Introduction

In this work, ResNet architectures are used. Specifically, we use ResNet10, 14, 18, 20 and 26 as the student networks and ResNet34 as the teacher network. The datasets used are CIFAR10, Imagenette and Imagewoof. Note that Imagenette and Imagewoof are subsets of ImageNet.

Preparation

  • Before running any experiments, you need to download the data and the saved teacher model weights to the appropriate locations.

  • The following script

    • downloads the datasets
    • saves 10%, 20%, 30% and 40% splits of each dataset separately
    • downloads teacher model weights for all 3 datasets
    # assuming you are in the root folder of the repository
    cd image_classification/scripts
    bash setup.sh
    

Experiments

For detailed information on the various experiments, refer to the paper. The following training arguments are common to all image classification experiments, listed with the values they can take:

  • dataset (-d) : imagenette, imagewoof, cifar10
  • model (-m) : resnet10, resnet14, resnet18, resnet20, resnet26, resnet34
  • number of epochs (-e) : an integer is required
  • percentage of dataset (-p) : 10, 20, 30, 40 (omit this argument for full-dataset experiments)
  • random seed (-s) : any integer seed (for reproducibility)
  • gpu (-g) : omit unless training on CPU (in which case, pass -g 'cpu'). On multi-GPU systems, set CUDA_VISIBLE_DEVICES=id in the terminal before running the experiment, where id is the ID of your GPU according to the nvidia-smi output.
  • Comet ML API key (-a) (optional) : to track experiments with Comet ML, pass your API key as the argument or set it as the default in the arguments.py file; otherwise, omit this argument.
  • Comet ML workspace (-w) (optional) : to track experiments with Comet ML, pass your workspace name as the argument or set it as the default in the arguments.py file; otherwise, omit this argument.

The following subsections give one example training command for each experiment.

No Teacher

Full Imagenette dataset, ResNet10

python3 no_teacher.py -d imagenette -m resnet10 -e 100 -s 0

Traditional KD (FitNets)

20% Imagewoof dataset, ResNet18

python3 traditional_kd.py -d imagewoof -m resnet18 -p 20 -e 100 -s 0
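
For intuition, the FitNets-style hint loss can be sketched as follows (a hedged illustration with assumed names and shapes, not the repository's actual code): a small 1x1-conv regressor maps the student's hint-layer features to the teacher's channel width, and the two are matched with an L2 loss.

# Sketch of a FitNets-style hint loss (hypothetical names; not the repo's code).
import torch
import torch.nn as nn
import torch.nn.functional as F

def fitnets_hint_loss(s_feat, t_feat, regressor):
    # s_feat: (B, Cs, H, W), t_feat: (B, Ct, H, W), regressor: nn.Conv2d(Cs, Ct, 1)
    return F.mse_loss(regressor(s_feat), t_feat)

# Example: a student hint layer with 128 channels, a teacher layer with 256.
regressor = nn.Conv2d(128, 256, kernel_size=1)
loss = fitnets_hint_loss(torch.randn(4, 128, 16, 16), torch.randn(4, 256, 16, 16), regressor)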

FSP KD

30% CIFAR10 dataset, ResNet14

python3 fsp_kd.py -d cifar10 -m resnet14 -p 30 -e 100 -s 0
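
FSP KD matches "flow of solution procedure" matrices between pairs of layers rather than raw feature maps; the distillation loss is the L2 distance between teacher and student FSP matrices. A hedged sketch of the FSP matrix itself (assumed shapes; not the repository's exact code):

# Sketch of an FSP matrix: a Gram-style product between the feature maps
# entering and leaving a stage (both assumed to share spatial size H x W).
import torch

def fsp_matrix(f_in, f_out):
    # f_in: (B, C1, H, W), f_out: (B, C2, H, W) -> (B, C1, C2)
    b, c1, h, w = f_in.shape
    c2 = f_out.shape[1]
    return torch.bmm(
        f_in.reshape(b, c1, h * w),
        f_out.reshape(b, c2, h * w).transpose(1, 2),
    ) / (h * w)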

Attention Transfer KD

10% Imagewoof dataset, ResNet26

python3 attention_transfer_kd.py -d imagewoof -m resnet26 -p 10 -e 100 -s 0
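
Attention transfer compares spatial attention maps derived from feature activations. A minimal sketch of the activation-based variant (hedged; the exact pooling variant used in this repository is an assumption):

# Sketch of activation-based attention transfer (assumed variant).
import torch
import torch.nn.functional as F

def attention_map(feat):
    # (B, C, H, W) -> (B, H*W): channel-wise mean of squared activations,
    # L2-normalized per sample.
    return F.normalize(feat.pow(2).mean(dim=1).flatten(1), dim=1)

def at_loss(s_feat, t_feat):
    return (attention_map(s_feat) - attention_map(t_feat)).pow(2).mean()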

Hinton KD

Full CIFAR10 dataset, ResNet14

python3 hinton_kd.py -d cifar10 -m resnet14 -e 100 -s 0
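
Hinton KD combines a temperature-softened KL term between teacher and student logits with the usual label loss. A minimal sketch (the temperature T and weight alpha below are illustrative defaults, not the repository's settings):

# Sketch of the Hinton KD loss (illustrative T and alpha values).
import torch.nn.functional as F

def hinton_kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.9):
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)  # rescale so gradient magnitude is roughly independent of T
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard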

Simultaneous KD (Proposed Baseline)

40% Imagenette dataset, ResNet20

python3 simultaneous_kd.py -d imagenette -m resnet20 -p 40 -e 100 -s 0
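
In the simultaneous-KD baseline, the feature-matching losses of all stages and the classification loss are optimized jointly in a single run. A hedged sketch of one training step (block and classifier names are assumptions, not the repository's API):

# Sketch of one simultaneous-KD step (hypothetical names).
import torch
import torch.nn as nn

def simultaneous_kd_step(teacher_blocks, student_blocks, classifier, opt, x, y):
    # x, y are assumed to be on the right device; classifier is assumed
    # to include pooling/flatten before its linear layer.
    mse, ce = nn.MSELoss(), nn.CrossEntropyLoss()
    loss, s_feat, t_feat = 0.0, x, x
    for sb, tb in zip(student_blocks, teacher_blocks):
        s_feat = sb(s_feat)
        with torch.no_grad():  # the teacher stays frozen
            t_feat = tb(t_feat)
        loss = loss + mse(s_feat, t_feat)  # per-stage feature matching
    loss = loss + ce(classifier(s_feat), y)  # plus the usual label loss
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()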

Stagewise KD (Proposed Method)

Full CIFAR10 dataset, ResNet10

python3 stagewise_kd.py -d cifar10 -m resnet10 -e 100 -s 0
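
Stagewise KD, by contrast, trains one student stage at a time against the corresponding teacher stage, freezing the earlier stages, and only then trains the classifier head on the labels. A hedged sketch with assumed names (see the repository code for the actual implementation):

# Sketch of stagewise KD (hypothetical names; not the repo's actual code).
import torch
import torch.nn as nn

def train_stagewise(teacher_blocks, student_blocks, classifier, loader, device):
    mse, ce = nn.MSELoss(), nn.CrossEntropyLoss()
    # Feature-matching stages: train student block k to mimic teacher block k,
    # with all earlier student blocks frozen.
    for k, block in enumerate(student_blocks):
        opt = torch.optim.Adam(block.parameters(), lr=1e-4)
        for x, _ in loader:  # labels are not needed while matching features
            x = x.to(device)
            with torch.no_grad():
                t_feat = x
                for tb in teacher_blocks[:k + 1]:
                    t_feat = tb(t_feat)
                s_feat = x
                for sb in student_blocks[:k]:  # frozen earlier stages
                    s_feat = sb(s_feat)
            loss = mse(block(s_feat), t_feat)
            opt.zero_grad()
            loss.backward()
            opt.step()
    # Final stage: train only the classifier head with the label loss.
    opt = torch.optim.Adam(classifier.parameters(), lr=1e-4)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            feat = x
            for sb in student_blocks:
                feat = sb(feat)
        loss = ce(classifier(feat), y)
        opt.zero_grad()
        loss.backward()
        opt.step()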

Semantic Segmentation

Introduction

In this work, ResNet backbones are used to construct symmetric U-Nets for semantic segmentation. Specifically, we use ResNet10, 14, 18, 20 and 26 as the backbones for the student networks and ResNet34 as the backbone for the teacher network. The dataset used is the Cambridge-driving Labeled Video Database (CamVid).

Preparation

  • The following script
    • downloads the data (and moves it to the appropriate folder)
    • saves 10%, 20%, 30% and 40% splits of the dataset separately
    • downloads the pretrained teacher weights to the appropriate folder
    # assuming you are in the root folder of the repository
    cd semantic_segmentation/scripts
    bash setup.sh
    

Experiments

For detailed information on the various experiments, refer to the paper. The following training arguments are common to all semantic segmentation experiments, listed with the values they can take:

  • dataset (-d) : camvid
  • model (-m) : resnet10, resnet14, resnet18, resnet20, resnet26, resnet34
  • number of epochs (-e) : an integer is required
  • percentage of dataset (-p) : 10, 20, 30, 40 (omit this argument for full-dataset experiments)
  • random seed (-s) : any integer seed (for reproducibility)
  • gpu (-g) : omit unless training on CPU (in which case, pass -g 'cpu'). On multi-GPU systems, set CUDA_VISIBLE_DEVICES=id in the terminal before running the experiment, where id is the ID of your GPU according to the nvidia-smi output.
  • Comet ML API key (-a) (optional) : to track experiments with Comet ML, pass your API key as the argument or set it as the default in the arguments.py file; otherwise, omit this argument.
  • Comet ML workspace (-w) (optional) : to track experiments with Comet ML, pass your workspace name as the argument or set it as the default in the arguments.py file; otherwise, omit this argument.

Note: Currently, there are no plans to add Attention Transfer KD and FSP KD experiments for semantic segmentation.

The following subsections give one example training command for each experiment.

No Teacher

Full CamVid dataset, ResNet10

python3 pretrain.py -d camvid -m resnet10 -e 100 -s 0

Traditional KD (FitNets)

20% CamVid dataset, ResNet18

python3 traditional_kd.py -d camvid -m resnet18 -p 20 -e 100 -s 0

Simultaneous KD (Proposed Baseline)

40% CamVid dataset, ResNet20

python3 simultaneous_kd.py -d camvid -m resnet20 -p 40 -e 100 -s 0

Stagewise KD (Proposed Method)

10% CamVid dataset, ResNet10

python3 stagewise_kd.py -d camvid -m resnet10 -p 10 -e 100 -s 0

Citation

If you use this code or method in your work, please cite it using the BibTeX entry below:

@misc{kulkarni2020data,
      title={Data Efficient Stagewise Knowledge Distillation}, 
      author={Akshay Kulkarni and Navid Panchi and Sharath Chandra Raparthy and Shital Chiddarwar},
      year={2020},
      eprint={1911.06786},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Built by Akshay Kulkarni, Navid Panchi and Sharath Chandra Raparthy.
