TF2 implementation of knowledge distillation using the "function matching" hypothesis from the paper Knowledge distillation: A good teacher is patient and consistent by Beyer et al.

Overview

FunMatch-Distillation

The techniques have been demonstrated using three datasets: Flowers102, Pet37, and Food101.

This repository provides Kaggle Kernel notebooks so that we can leverage the free TPU v3-8 to run the long training schedules. Please refer to the "About the notebooks" section.

Importance

The importance of knowledge distillation lies in its practical usefulness. With the recipes from "function matching", we can now perform knowledge distillation using a principled approach yielding student models that can actually match the performance of their teacher models. This essentially allows us to compress bigger models into (much) smaller ones thereby reducing storage costs and improving inference speed.

Key ingredients

  • No use of ground-truth labels during distillation.
  • Teacher and student should see the same images during distillation, as opposed to differently augmented views of the same images.
  • An aggressive form of MixUp as the key augmentation recipe. MixUp is paired with "Inception-style" cropping (implemented in this script).
  • A LONG training schedule for distillation: at least 1000 epochs to get good results without overfitting. The paper studies this in detail and finds that a long training schedule is paramount.
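The first three ingredients can be sketched in a few lines of framework-agnostic Python. This is a minimal illustration, not the code from the notebooks. It assumes the loss is the KL divergence between the teacher's and the student's softmax outputs, and that "aggressive" MixUp draws its mixing coefficient uniformly from [0, 1]. The names `toy_teacher`, `toy_student`, `mixup_pair`, and `distillation_loss` are hypothetical stand-ins for the real networks and pipeline.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * (math.log(pi + eps) - math.log(qi + eps))
               for pi, qi in zip(p, q))


def mixup_pair(x_a, x_b, lam):
    """Blend two inputs with mixing coefficient lam."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(x_a, x_b)]


def distillation_loss(teacher_fn, student_fn, image, temperature=1.0):
    """Both models score the SAME input; no ground-truth label is used."""
    p_teacher = softmax(teacher_fn(image), temperature)
    p_student = softmax(student_fn(image), temperature)
    return kl_divergence(p_teacher, p_student)


# Hypothetical toy "networks" mapping a 2-d input to 3 class logits.
def toy_teacher(x):
    return [2.0 * x[0], x[1], -x[0]]


def toy_student(x):
    return [1.5 * x[0], 0.5 * x[1], 0.0]


rng = random.Random(0)
lam = rng.uniform(0.0, 1.0)  # assumption: lam ~ U[0, 1] for "aggressive" MixUp
mixed = mixup_pair([1.0, 0.0], [0.0, 1.0], lam)
loss = distillation_loss(toy_teacher, toy_student, mixed)
print(f"lam={lam:.3f}  loss={loss:.4f}")
```

Note that the mixed input is computed once and passed to both models, which is the "same images" requirement. The fourth ingredient, the long schedule, is simply the number of times this step is repeated.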

Results

The table below summarizes the results of my experiments. In all cases, the teacher is a BiT-ResNet101x3 model and the student is a BiT-ResNet50x1. For fun, you can also try to distill into other model families. BiT stands for "Big Transfer" and was proposed in this paper.

| Dataset | Teacher/Student | Top-1 Acc on Test | Location |
|---|---|---|---|
| Flowers102 | Teacher | 98.18% | Link |
| Flowers102 | Student (1000 epochs) | 81.02% | Link |
| Pet37 | Teacher | 90.92% | Link |
| Pet37 | Student (300 epochs) | 81.3% | Link |
| Pet37 | Student (1000 epochs) | 86% | Link |
| Food101 | Teacher | 85.52% | Link |
| Food101 | Student (100 epochs) | 76.06% | Link |

(Location denotes the trained model location.)

These results are consistent with Table 4 of the original paper.

None of the above student training regimes showed signs of overfitting, so training for longer may improve the results further. The authors also showed that Shampoo can reach similar performance much quicker than Adam during distillation, so it may well be possible to match these numbers in fewer epochs with Shampoo.

A few differences from the original implementation:

  • The authors use BiT-ResNet152x2 as a teacher.
  • The mixup() variant I used produces a pair of duplicate images if the number of images is even. With 8 workers, that becomes 8 pairs. This may have led to the reduced performance. We can overcome this by using tf.roll(images, 1, axis=0) instead of tf.reverse in the mixup() function. Thanks to Lucas Beyer for pointing this out.
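One way to see the difference between the two pairings is to track only batch positions, without any images. The sketch below is an illustration under that simplification, not the repository's mixup() function. The helper names are hypothetical; they mimic how reversing a batch and rolling it by one position choose each image's mixing partner.

```python
def reverse_partners(n):
    """Partner index per position when the batch is reversed along axis 0."""
    return [n - 1 - i for i in range(n)]


def roll_partners(n, shift=1):
    """Partner index per position when the batch is rolled by `shift` along axis 0."""
    return [(i - shift) % n for i in range(n)]


def distinct_source_pairs(partners):
    """Unordered sets of source images that get blended together."""
    return {frozenset((i, j)) for i, j in enumerate(partners)}


n = 8  # hypothetical per-worker batch size
print("reverse:", len(distinct_source_pairs(reverse_partners(n))))
print("roll:   ", len(distinct_source_pairs(roll_partners(n))))
```

With reversal, positions i and n - 1 - i blend the same two source images, so a batch of n images yields only n / 2 distinct combinations. Rolling by one gives every position a different combination once n is larger than 2.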

About the notebooks

All the notebooks are fully runnable on Kaggle Kernel. The only requirement is a billing-enabled GCP account, so that you can use GCS Buckets to store data.

| Notebook | Description | Kaggle Kernel |
|---|---|---|
| train_bit.ipynb | Shows how to train the teacher model. | Link |
| train_bit_keras_tuner.ipynb | Shows how to run hyperparameter tuning using Keras Tuner for the teacher model. | Link |
| funmatch_distillation.ipynb | Shows an implementation of the recipes from "function matching". | Link |

These are only demonstrated on the Pet37 dataset but will work out-of-the-box for the other datasets too.

TFRecords

For convenience, TFRecords of different datasets are provided:

| Dataset | TFRecords |
|---|---|
| Flowers102 | Link |
| Pet37 | Link |
| Food101 | Link |

Paper citation

@misc{beyer2021knowledge,
      title={Knowledge distillation: A good teacher is patient and consistent}, 
      author={Lucas Beyer and Xiaohua Zhai and Amélie Royer and Larisa Markeeva and Rohan Anil and Alexander Kolesnikov},
      year={2021},
      eprint={2106.05237},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgements

Huge thanks to Lucas Beyer (first author of the paper) for providing suggestions on the initial version of the implementation.

Thanks to the ML-GDE program for providing GCP credits.

Thanks to TRC for providing Cloud TPU access.

Releases(v4.0.0)
Owner
Sayak Paul