Official PyTorch implementation of PS-KD

Overview


Self-Knowledge Distillation with Progressive Refinement of Targets (PS-KD)

Accepted at ICCV 2021, oral presentation

  • Official PyTorch implementation of Self-Knowledge Distillation with Progressive Refinement of Targets (PS-KD).
    [Slides] [Paper] [Video]
  • Kyungyul Kim, ByeongMoon Ji, Doyoung Yoon and Sangheum Hwang

Abstract

The generalization capability of deep neural networks has been substantially improved by applying a wide spectrum of regularization methods, e.g., restricting function space, injecting randomness during training, augmenting data, etc. In this work, we propose a simple yet effective regularization method named progressive self-knowledge distillation (PS-KD), which progressively distills a model's own knowledge to soften hard targets (i.e., one-hot vectors) during training. Hence, it can be interpreted within a framework of knowledge distillation as a student becomes a teacher itself. Specifically, targets are adjusted adaptively by combining the ground-truth and past predictions from the model itself. Please refer to the paper for more details.
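
In code, the target-softening step amounts to a convex combination of the one-hot label and the model's softmax output from the previous epoch, with the mixing coefficient alpha_t increased linearly toward alpha_T (the --alpha_T flag in the commands below). The following is a minimal, illustrative sketch, not the repository's actual training loop; the function names and the soft-target cross-entropy helper are assumptions.

import torch

def pskd_soft_targets(hard_targets, past_preds, epoch, total_epochs, alpha_T=0.8):
    # hard_targets: one-hot ground-truth vectors, shape (batch, num_classes)
    # past_preds:   the model's own softmax predictions for the same samples,
    #               stored at the previous epoch (epoch t-1)
    # alpha_t grows linearly from 0 to alpha_T over training, so early epochs
    # rely mostly on the hard labels and later epochs lean on the model's own,
    # progressively refined, knowledge.
    alpha_t = alpha_T * (epoch / total_epochs)
    alpha_t = max(0.0, min(alpha_t, alpha_T))
    return (1.0 - alpha_t) * hard_targets + alpha_t * past_preds

def soft_cross_entropy(logits, soft_targets):
    # Cross-entropy against soft targets; a hypothetical helper, since the
    # standard index-based nn.CrossEntropyLoss does not accept soft labels
    # in the PyTorch versions listed under Requirements.
    log_probs = torch.nn.functional.log_softmax(logits, dim=1)
    return torch.mean(torch.sum(-soft_targets * log_probs, dim=1))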

Requirements

We have tested the code in the following environment:

  • Python 3.7.7 / PyTorch (>=1.6.0) / torchvision (>=0.7.0)

Datasets

Currently, only the CIFAR-100 and ImageNet datasets are supported.

To verify the effectiveness of PS-KD on object detection and machine translation tasks, we also used:

  • For object detection: Pascal VOC
  • For machine translation: IWSLT 15 English-German / German-English, Multi30k.
  • (Please refer to the paper for more details)

How to Run

Single-node & Multi-GPU Training

To train a single model on a single node with multiple GPUs, run the following command:

$ python3 main.py --lr 0.1 \
                  --lr_decay_schedule 150 225 \
                  --PSKD \
                  --experiments_dir '<set your own path>' \
                  --classifier_type 'ResNet18' \
                  --data_path '<root your own data path>' \
                  --data_type '<cifar100 or imagenet>' \
                  --alpha_T 0.8 \
                  --rank 0 \
                  --world_size 1 \
                  --multiprocessing_distributed True

Multi-node Training

To train a single model across two nodes, for instance, run the commands below, one on each node:

# on node #0
$ python3 main.py --lr 0.1 \
                  --lr_decay_schedule 150 225 \
                  --PSKD \
                  --experiments_dir '<set your own path>' \
                  --classifier_type 'ResNet18' \
                  --data_path '<root your own data path>' \
                  --data_type '<cifar100 or imagenet>' \
                  --alpha_T 0.8 \
                  --rank 0 \
                  --world_size 2 \
                  --dist_url tcp://{master_ip}:{master_port} \
                  --multiprocessing_distributed
# on node #1
$ python3 main.py --lr 0.1 \
                  --lr_decay_schedule 150 225 \
                  --PSKD \
                  --experiments_dir '<set your own path>' \
                  --classifier_type 'ResNet18' \
                  --data_path '<root your own data path>' \
                  --data_type '<cifar100 or imagenet>' \
                  --alpha_T 0.8 \
                  --rank 1 \
                  --world_size 2 \
                  --dist_url tcp://{master_ip}:{master_port} \
                  --multiprocessing_distributed

Saving & Loading Checkpoints

Saved Filenames

  • save_dir is determined automatically (with sequential number suffixes) unless otherwise specified.
  • Model checkpoints are saved to ./{experiments_dir}/models/checkpoint_{epoch}.pth.
  • The best checkpoint is saved to ./{experiments_dir}/models/checkpoint_best.pth.

Loading Checkpoints (resume)

  • Pass the saved checkpoint path via the --resume argument, as in the example below.
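
For example, to resume single-node training from the best checkpoint, the single-node command above can be rerun with --resume added (the checkpoint path here is illustrative; the other flags should match the original run):

$ python3 main.py --lr 0.1 \
                  --lr_decay_schedule 150 225 \
                  --PSKD \
                  --experiments_dir '<set your own path>' \
                  --classifier_type 'ResNet18' \
                  --data_path '<root your own data path>' \
                  --data_type '<cifar100 or imagenet>' \
                  --alpha_T 0.8 \
                  --rank 0 \
                  --world_size 1 \
                  --multiprocessing_distributed True \
                  --resume './<experiments_dir>/models/checkpoint_best.pth'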

Experimental Results

Performance measures

  • Top-1 Error / Top-5 Error
  • Negative Log Likelihood (NLL)
  • Expected Calibration Error (ECE), computation sketched below
  • Area Under the Risk-coverage Curve (AURC)
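
For reference, ECE compares average confidence with accuracy inside equal-width confidence bins. The snippet below is a minimal, hypothetical sketch of this computation; the repository's own metric code may bin or scale differently. The tables below appear to report ECE as a percentage and AURC scaled by 10^3, as in the paper.

import torch

def expected_calibration_error(confidences, correct, n_bins=15):
    # confidences: max softmax probabilities, shape (N,)
    # correct:     0/1 correctness indicators, shape (N,)
    bin_edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros(1)
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        prop = in_bin.float().mean()
        if prop > 0:
            acc = correct[in_bin].float().mean()
            conf = confidences[in_bin].mean()
            # weight the |confidence - accuracy| gap by the bin's sample share
            ece += prop * torch.abs(conf - acc)
    return ece.item()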

Results on CIFAR-100

| Model + Method | Dataset | Top-1 Error | Top-5 Error | NLL | ECE | AURC |
|----------------|---------|-------------|-------------|-----|-----|------|
| PreAct ResNet-18 (baseline) | CIFAR-100 | 24.18 | 6.90 | 1.10 | 11.84 | 67.65 |
| PreAct ResNet-18 + Label Smoothing | CIFAR-100 | 20.94 | 6.02 | 0.98 | 10.79 | 57.74 |
| PreAct ResNet-18 + CS-KD [CVPR'20] | CIFAR-100 | 21.30 | 5.70 | 0.88 | 6.24 | 56.56 |
| PreAct ResNet-18 + TF-KD [CVPR'20] | CIFAR-100 | 22.88 | 6.01 | 1.05 | 11.96 | 61.77 |
| PreAct ResNet-18 + PS-KD | CIFAR-100 | 20.82 | 5.10 | 0.76 | 1.77 | 52.10 |
| PreAct ResNet-101 (baseline) | CIFAR-100 | 20.75 | 5.28 | 0.89 | 10.02 | 55.45 |
| PreAct ResNet-101 + Label Smoothing | CIFAR-100 | 19.84 | 5.07 | 0.93 | 3.43 | 95.76 |
| PreAct ResNet-101 + CS-KD [CVPR'20] | CIFAR-100 | 20.76 | 5.62 | 1.02 | 12.18 | 64.44 |
| PreAct ResNet-101 + TF-KD [CVPR'20] | CIFAR-100 | 20.13 | 5.10 | 0.84 | 6.14 | 58.8 |
| PreAct ResNet-101 + PS-KD | CIFAR-100 | 19.43 | 4.30 | 0.74 | 6.92 | 49.01 |
| DenseNet-121 (baseline) | CIFAR-100 | 20.05 | 4.99 | 0.82 | 7.34 | 52.21 |
| DenseNet-121 + Label Smoothing | CIFAR-100 | 19.80 | 5.46 | 0.92 | 3.76 | 91.06 |
| DenseNet-121 + CS-KD [CVPR'20] | CIFAR-100 | 20.47 | 6.21 | 1.07 | 13.80 | 73.37 |
| DenseNet-121 + TF-KD [CVPR'20] | CIFAR-100 | 19.88 | 5.10 | 0.85 | 7.33 | 69.23 |
| DenseNet-121 + PS-KD | CIFAR-100 | 18.73 | 3.90 | 0.69 | 3.71 | 45.55 |
| ResNeXt-29 (baseline) | CIFAR-100 | 18.65 | 4.47 | 0.74 | 4.17 | 44.27 |
| ResNeXt-29 + Label Smoothing | CIFAR-100 | 17.60 | 4.23 | 1.05 | 22.14 | 41.92 |
| ResNeXt-29 + CS-KD [CVPR'20] | CIFAR-100 | 18.26 | 4.37 | 0.80 | 5.95 | 42.11 |
| ResNeXt-29 + TF-KD [CVPR'20] | CIFAR-100 | 17.33 | 3.87 | 0.74 | 6.73 | 40.34 |
| ResNeXt-29 + PS-KD | CIFAR-100 | 17.28 | 3.60 | 0.72 | 9.18 | 40.19 |
| PyramidNet-200 (baseline) | CIFAR-100 | 16.80 | 3.69 | 0.73 | 8.04 | 36.95 |
| PyramidNet-200 + Label Smoothing | CIFAR-100 | 17.82 | 4.72 | 0.89 | 3.46 | 105.02 |
| PyramidNet-200 + CS-KD [CVPR'20] | CIFAR-100 | 18.31 | 5.70 | 1.17 | 14.70 | 70.05 |
| PyramidNet-200 + TF-KD [CVPR'20] | CIFAR-100 | 16.48 | 3.37 | 0.79 | 10.48 | 37.04 |
| PyramidNet-200 + PS-KD | CIFAR-100 | 15.49 | 3.08 | 0.56 | 1.83 | 32.14 |

Results on ImageNet

| Model + Method | Dataset | Top-1 Error | Top-5 Error | NLL | ECE | AURC |
|----------------|---------|-------------|-------------|-----|-----|------|
| DenseNet-264* | ImageNet | 22.15 | 6.12 | -- | -- | -- |
| ResNet-152 | ImageNet | 22.19 | 6.19 | 0.88 | 3.84 | 61.79 |
| ResNet-152 + Label Smoothing | ImageNet | 21.73 | 5.85 | 0.92 | 3.91 | 68.24 |
| ResNet-152 + CS-KD [CVPR'20] | ImageNet | 21.61 | 5.92 | 0.90 | 5.79 | 62.12 |
| ResNet-152 + TF-KD [CVPR'20] | ImageNet | 22.76 | 6.43 | 0.91 | 4.70 | 65.28 |
| ResNet-152 + PS-KD | ImageNet | 21.41 | 5.86 | 0.84 | 2.51 | 61.01 |

* denotes results reported in the original papers

Citation

If you find this repository useful, please consider giving it a star and citing PS-KD:

@InProceedings{Kim_2021_ICCV,
    author    = {Kim, Kyungyul and Ji, ByeongMoon and Yoon, Doyoung and Hwang, Sangheum},
    title     = {Self-Knowledge Distillation With Progressive Refinement of Targets},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {6567-6576}
}

Contact for Issues

License

Copyright (c) 2021-present LG CNS Corp.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.