[NeurIPS 2020] Semi-Supervision (Unlabeled Data) & Self-Supervision Improve Class-Imbalanced / Long-Tailed Learning

Rethinking the Value of Labels for Improving Class-Imbalanced Learning

This repository contains the implementation code for paper:
Rethinking the Value of Labels for Improving Class-Imbalanced Learning
Yuzhe Yang and Zhi Xu
34th Conference on Neural Information Processing Systems (NeurIPS), 2020
[Website] [arXiv] [Paper] [Slides] [Video]

If you find this code or idea useful, please consider citing our work:

@inproceedings{yang2020rethinking,
  title={Rethinking the Value of Labels for Improving Class-Imbalanced Learning},
  author={Yang, Yuzhe and Xu, Zhi},
  booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
  year={2020}
}

Overview

In this work, we show theoretically and empirically that both semi-supervised learning (using unlabeled data) and self-supervised pre-training (first pre-training the model with self-supervision) can substantially improve performance on imbalanced (long-tailed) datasets, regardless of the degree of imbalance in the labeled/unlabeled data and of the base training technique.

Semi-Supervised Imbalanced Learning: Using unlabeled data helps shape clearer class boundaries and results in better class separation, especially for the tail classes.

Self-Supervised Imbalanced Learning: Self-supervised pre-training (SSP) helps mitigate tail-class leakage during testing, which results in better learned boundaries and representations.

Installation

Prerequisites

Dependencies

  • PyTorch (>= 1.2, tested on 1.4)
  • yaml
  • scikit-learn
  • TensorboardX

Code Overview

Main Files

Main Arguments

  • --dataset: name of chosen long-tailed dataset
  • --imb_factor: imbalance factor (inverse value of imbalance ratio \rho in the paper)
  • --imb_factor_unlabel: imbalance factor for unlabeled data (inverse value of the unlabeled-data imbalance ratio \rho_U)
  • --pretrained_model: path to self-supervised pre-trained models
  • --resume: path to resume checkpoint (also for evaluation)

Getting Started

Semi-Supervised Imbalanced Learning

Unlabeled data sourcing

CIFAR-10-LT: The CIFAR-10 unlabeled data is prepared following this repo, using the 80M TinyImages dataset. In short, a data sourcing model is trained to distinguish the CIFAR-10 classes from a "non-CIFAR" class. For each class, images are then ranked by prediction confidence, and the unlabeled (imbalanced) datasets are constructed accordingly. Use the following link to download the prepared unlabeled data, and place it in your data_path:

SVHN-LT: Since the SVHN dataset itself contains an extra split with 531.1K additional (labeled) samples, these are used directly to simulate the unlabeled dataset.

Note that the class imbalance in unlabeled data is also considered, which is controlled by --imb_factor_unlabel (\rho_U in the paper). See imbalance_cifar.py and imbalance_svhn.py for details.
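The confidence-based sourcing described for CIFAR-10-LT can be sketched roughly as below. This is an illustration under stated assumptions, not the repository's pipeline: the function name select_unlabeled is hypothetical, the last probability entry is assumed to be the "non-CIFAR" class, and candidates are assumed to be grouped by their predicted class before ranking.

```python
def select_unlabeled(probs, counts, num_classes=10):
    """Pick the most confident candidate images for each class.

    probs:  list of per-image probability lists of length num_classes + 1;
            the last entry is ASSUMED to be the "non-CIFAR" class.
    counts: number of images to keep per class (may be imbalanced).
    Returns a dict mapping class index -> list of selected image indices.
    """
    per_class = {c: [] for c in range(num_classes)}
    for idx, p in enumerate(probs):
        pred = max(range(len(p)), key=p.__getitem__)
        if pred < num_classes:  # skip images predicted as "non-CIFAR"
            per_class[pred].append((p[pred], idx))

    selected = {}
    for c, scored in per_class.items():
        scored.sort(key=lambda t: t[0], reverse=True)  # most confident first
        selected[c] = [idx for _, idx in scored[: counts[c]]]
    return selected
```

Passing an imbalanced counts list (for instance, one derived from \rho_U) is what would make the resulting unlabeled set imbalanced in this sketch.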

Semi-supervised learning with pseudo-labeling

To perform pseudo-labeling (self-training), a base classifier is first trained on the original imbalanced dataset. With the trained base classifier, pseudo-labels can be generated using

python gen_pseudolabels.py --resume <ckpt-path> --data_dir <data_path> --output_dir <output_path> --output_filename <save_name>
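Conceptually, pseudo-labeling assigns each unlabeled sample the class its base classifier scores highest, then trains on the union of labeled and pseudo-labeled data. The sketch below illustrates that idea only; both function names are hypothetical, hard argmax labels are an assumption, and the actual logic lives in gen_pseudolabels.py and train_semi.py.

```python
def generate_pseudo_labels(logits):
    """Return the argmax class index for each row of classifier scores."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def build_semi_supervised_set(labeled, unlabeled_inputs, logits):
    """Combine (input, label) pairs with pseudo-labeled unlabeled inputs.

    labeled:          list of (input, label) pairs from the imbalanced set.
    unlabeled_inputs: list of inputs without labels.
    logits:           base-classifier scores, one row per unlabeled input.
    """
    pseudo = generate_pseudo_labels(logits)
    return list(labeled) + list(zip(unlabeled_inputs, pseudo))
```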

We provide generated pseudo-label files for CIFAR-10-LT & SVHN-LT with \rho=50, using base models trained with the standard cross-entropy (CE) loss:

To train with unlabeled data, for example, on CIFAR-10-LT with \rho=50 and \rho_U=50

python train_semi.py --dataset cifar10 --imb_factor 0.02 --imb_factor_unlabel 0.02

Self-Supervised Imbalanced Learning

Self-supervised pre-training (SSP)

To perform Rotation SSP on CIFAR-10-LT with \rho=100

python pretrain_rot.py --dataset cifar10 --imb_factor 0.01
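Rotation-based self-supervision trains a network to predict which of four rotations (0, 90, 180, or 270 degrees) was applied to an image, so no class labels are needed. The sketch below only shows how such pretext samples could be built from images stored as nested lists; the function names are hypothetical and it is not taken from pretrain_rot.py.

```python
def rot90(image):
    """Rotate a 2-D image (list of rows) by 90 degrees counterclockwise."""
    return [list(row) for row in zip(*image)][::-1]


def rotation_pretext_batch(images):
    """Return (rotated_image, rotation_label) pairs.

    Labels 0..3 correspond to rotations of 0, 90, 180, and 270 degrees.
    """
    samples = []
    for image in images:
        rotated = [list(row) for row in image]  # copy, so inputs stay untouched
        for label in range(4):
            samples.append((rotated, label))
            rotated = rot90(rotated)  # new list each time, no aliasing
    return samples
```

Each input image contributes four training samples, and the rotation index serves as the target for a 4-way classification head during pre-training.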

To perform MoCo SSP on ImageNet-LT

python pretrain_moco.py --dataset imagenet --data <data_path>

Network training with SSP models

Train on CIFAR-10-LT with \rho=100

python train.py --dataset cifar10 --imb_factor 0.01 --pretrained_model <path_to_ssp_model>

Train on ImageNet-LT / iNaturalist 2018

python -m imagenet_inat.main --cfg <path_to_ssp_config> --model_dir <path_to_ssp_model>

Results and Models

All related data and checkpoints can be found via this link. Individual results and checkpoints are detailed as follows.

Semi-Supervised Imbalanced Learning

CIFAR-10-LT

Model                                   Top-1 Error (%)   Download
CE + D_U@5x (\rho=50 and \rho_U=1)      16.79             ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=25)     16.88             ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=50)     18.36             ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=100)    19.94             ResNet-32

SVHN-LT

Model                                   Top-1 Error (%)   Download
CE + D_U@5x (\rho=50 and \rho_U=1)      13.07             ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=25)     13.36             ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=50)     13.16             ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=100)    14.54             ResNet-32

Test a pretrained checkpoint

python train_semi.py --dataset cifar10 --resume <ckpt-path> -e

Self-Supervised Imbalanced Learning

CIFAR-10-LT

  • Self-supervised pre-trained models (Rotation)

    Dataset Setting   \rho=100    \rho=50     \rho=10
    Download          ResNet-32   ResNet-32   ResNet-32
  • Final models (200 epochs)

    Model                \rho   Top-1 Error (%)   Download
    CE(Uniform) + SSP    10     12.28             ResNet-32
    CE(Uniform) + SSP    50     21.80             ResNet-32
    CE(Uniform) + SSP    100    26.50             ResNet-32
    CE(Balanced) + SSP   10     11.57             ResNet-32
    CE(Balanced) + SSP   50     19.60             ResNet-32
    CE(Balanced) + SSP   100    23.47             ResNet-32

CIFAR-100-LT

  • Self-supervised pre-trained models (Rotation)

    Dataset Setting   \rho=100    \rho=50     \rho=10
    Download          ResNet-32   ResNet-32   ResNet-32
  • Final models (200 epochs)

    Model                \rho   Top-1 Error (%)   Download
    CE(Uniform) + SSP    10     42.93             ResNet-32
    CE(Uniform) + SSP    50     54.96             ResNet-32
    CE(Uniform) + SSP    100    59.60             ResNet-32
    CE(Balanced) + SSP   10     41.94             ResNet-32
    CE(Balanced) + SSP   50     52.91             ResNet-32
    CE(Balanced) + SSP   100    56.94             ResNet-32

ImageNet-LT

  • Self-supervised pre-trained models (MoCo)
    [ResNet-50]

  • Final models (90 epochs)

    Model                Top-1 Error (%)   Download
    CE(Uniform) + SSP    54.4              ResNet-50
    CE(Balanced) + SSP   52.4              ResNet-50
    cRT + SSP            48.7              ResNet-50

iNaturalist 2018

  • Self-supervised pre-trained models (MoCo)
    [ResNet-50]

  • Final models (90 epochs)

    Model                Top-1 Error (%)   Download
    CE(Uniform) + SSP    35.6              ResNet-50
    CE(Balanced) + SSP   34.1              ResNet-50
    cRT + SSP            31.9              ResNet-50

Test a pretrained checkpoint

# test on CIFAR-10 / CIFAR-100
python train.py --dataset cifar10 --resume <ckpt-path> -e

# test on ImageNet-LT / iNaturalist 2018
python -m imagenet_inat.main --cfg <path_to_ssp_config> --model_dir <path_to_model> --test

Acknowledgements

This code is partly based on the open-source implementations from the following sources: OpenLongTailRecognition, classifier-balancing, LDAM-DRW, MoCo, and semisup-adv.

Contact

If you have any questions, feel free to contact us through email ([email protected] & [email protected]) or GitHub issues. Enjoy!

Owner
Yuzhe Yang
Ph.D. student at MIT CSAIL