[CVPR 2022] Official code for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration"

Overview

MDCA Calibration

This is the official PyTorch implementation for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration".

Abstract

Deep Neural Networks (DNNs) make overconfident mistakes which can prove to be probematic in deployment in safety critical applications. Calibration is aimed to enhance trust in DNNs. The goal of our proposed Multi-Class Difference in Confidence and Accuracy (MDCA) loss is to align the probability estimates in accordance with accuracy thereby enhancing the trust in DNN decisions. MDCA can be used in case of image classification, image segmentation, and natural language classification tasks.

Teaser

Above image shows comparison of classwise reliability diagrams of Cross-Entropy vs. our proposed method.

Requirements

  • Python 3.8
  • PyTorch 1.8

Directly install using pip

 pip install -r requirements.txt

Training scripts:

Refer to the scripts folder to train for every model and dataset. Overall the command to train looks like below where each argument can be changed accordingly on how to train. Also refer to dataset/__init__.py and models/__init__.py for correct arguments to train with. Argument parser can be found in utils/argparser.py.

Train with cross-entropy:

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss cross_entropy 

Train with FL+MDCA: Also mention the gamma (for Focal Loss) and beta (Weight assigned to MDCA) to train FL+MDCA with

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss FL+MDCA --gamma 1.0 --beta 1.0 

Train with NLL+MDCA:

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss NLL+MDCA --beta 1.0

Post Hoc Calibration:

To do post-hoc calibration, we can use the following command.

lr and patience value is used for Dirichlet calibration. To change range of grid-search in dirichlet calibration, refer to posthoc_calibrate.py.

python posthoc_calibrate.py --dataset cifar10 --model resnet56 --lr 0.001 --patience 5 --checkpoint path/to/your/trained/model

Other Experiments (Dataset Drift, Dataset Imbalance):

experiments folder contains our experiments on PACS, Rotated MNIST and Imbalanced CIFAR10. Please refer to the scripts provided to run them.

Citation

If you find our work useful in your research, please cite the following:

@InProceedings{StitchInTime,
    author    = {R. Hebbalaguppe, J. Prakash, N. Madan, C. Arora},
    title     = {A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022}
}

Contact

For questions about our paper or code, please contact any of the authors (@neelabh17, @bicycleman15, @rhebbalaguppe ) or raise an issue on GitHub.

References:

The code is adapted from the following repositories:

[1] bearpaw/pytorch-classification [2] torrvision/focal_calibration [3] Jonathan-Pearce/calibration_library

Owner
MDCA Calibration
MDCA Calibration
Implementation of GGB color space

GGB Color Space This package is implementation of GGB color space from Development of a Robust Algorithm for Detection of Nuclei and Classification of

Resha Dwika Hefni Al-Fahsi 2 Oct 06, 2021
Multi-robot collaborative exploration and mapping through Voronoi partition and DRL in unknown environment

Voronoi Multi_Robot Collaborate Exploration Introduction In the unknown environment, the cooperative exploration of multiple robots is completed by Vo

PeaceWord 6 Nov 22, 2022
Gesture Volume Control Using OpenCV and MediaPipe

This Project Uses OpenCV and MediaPipe Hand solutions to identify hands and Change system volume by taking thumb and index finger positions

Pratham Bhatnagar 6 Sep 12, 2022
A simple implementation of Kalman filter in single object tracking

kalman-filter-in-single-object-tracking A simple implementation of Kalman filter in single object tracking https://www.bilibili.com/video/BV1Qf4y1J7D4

130 Dec 26, 2022
Unofficial pytorch implementation of paper "One-Shot Free-View Neural Talking-Head Synthesis for Video Conferencing"

One-Shot Free-View Neural Talking Head Synthesis Unofficial pytorch implementation of paper "One-Shot Free-View Neural Talking-Head Synthesis for Vide

ZLH 406 Dec 23, 2022
converts nominal survey data into a numerical value based on a dictionary lookup.

SWAP RATE Converts nominal survey data into a numerical values based on a dictionary lookup. It allows the user to switch nominal scale data from text

Jake Rhodes 1 Jan 18, 2022
This repo is for segmentation of T2 hyp regions in gliomas.

T2-Hyp-Segmentor This repo is for segmentation of T2 hyp regions in gliomas. By downloading the model from here you can use it to segment your T2w ima

1 Jan 18, 2022
Simple tools for logging and visualizing, loading and training

TNT TNT is a library providing powerful dataloading, logging and visualization utilities for Python. It is closely integrated with PyTorch and is desi

1.5k Jan 02, 2023
Sparse Progressive Distillation: Resolving Overfitting under Pretrain-and-Finetune Paradigm

Sparse Progressive Distillation: Resolving Overfitting under Pretrain-and-Finetu

3 Dec 05, 2022
Code for the ICML 2021 paper "Bridging Multi-Task Learning and Meta-Learning: Towards Efficient Training and Effective Adaptation", Haoxiang Wang, Han Zhao, Bo Li.

Bridging Multi-Task Learning and Meta-Learning Code for the ICML 2021 paper "Bridging Multi-Task Learning and Meta-Learning: Towards Efficient Trainin

AI Secure 57 Dec 15, 2022
a morph transfer UGATIT for image translation.

Morph-UGATIT a morph transfer UGATIT for image translation. Introduction 中文技术文档 This is Pytorch implementation of UGATIT, paper "U-GAT-IT: Unsupervise

55 Nov 14, 2022
This is the code for the paper "Contrastive Clustering" (AAAI 2021)

Contrastive Clustering (CC) This is the code for the paper "Contrastive Clustering" (AAAI 2021) Dependency python=3.7 pytorch=1.6.0 torchvision=0.8

Yunfan Li 210 Dec 30, 2022
SEJE Pytorch implementation

SEJE is a prototype for the paper Learning Text-Image Joint Embedding for Efficient Cross-Modal Retrieval with Deep Feature Engineering. Contents Inst

0 Oct 21, 2021
A Model for Natural Language Attack on Text Classification and Inference

TextFooler A Model for Natural Language Attack on Text Classification and Inference This is the source code for the paper: Jin, Di, et al. "Is BERT Re

Di Jin 418 Dec 16, 2022
Creating a Linear Program Solver by Implementing the Simplex Method in Python with NumPy

Creating a Linear Program Solver by Implementing the Simplex Method in Python with NumPy Simplex Algorithm is a popular algorithm for linear programmi

Reda BELHAJ 2 Oct 12, 2022
ByteTrack超详细教程!训练自己的数据集&&摄像头实时检测跟踪

ByteTrack超详细教程!训练自己的数据集&&摄像头实时检测跟踪

Double-zh 45 Dec 19, 2022
RealTime Emotion Recognizer for Machine Learning Study Jam's demo

Emotion recognizer Table of contents Clone project Dataset Install dependencies Main program Demo 1. Clone project git clone https://github.com/GDSC20

Google Developer Student Club - UIT 1 Oct 05, 2021
Torchyolo - Yolov3 ve Yolov4 modellerin Pytorch uygulamasıdır

TORCHYOLO : Yolo Modellerin Pytorch Uygulaması Yapılacaklar: Yolov3 model.py ve

Kadir Nar 3 Aug 22, 2022
A Fast and Stable GAN for Small and High Resolution Imagesets - pytorch

A Fast and Stable GAN for Small and High Resolution Imagesets - pytorch The official pytorch implementation of the paper "Towards Faster and Stabilize

Bingchen Liu 455 Jan 08, 2023
FCA: Learning a 3D Full-coverage Vehicle Camouflage for Multi-view Physical Adversarial Attack

FCA: Learning a 3D Full-coverage Vehicle Camouflage for Multi-view Physical Adversarial Attack Case study of the FCA. The code can be find in FCA. Cas

IDRL 21 Dec 15, 2022