[NeurIPS 2021] “Improving Contrastive Learning on Imbalanced Data via Open-World Sampling”,

Related tags

Deep LearningMAK
Overview

Improving Contrastive Learning on Imbalanced Data via Open-World Sampling

Introduction

Contrastive learning approaches have achieved great success in learning visual representations with few labels. That implies a tantalizing possibility of scaling them up beyond a curated target benchmark, to incorporating more unlabeled images from the internet-scale external sources to enhance its performance. However, in practice, with larger amount of unlabeled data, it requires more compute resources for the bigger model size and longer training. Moreover, open-world unlabeled data have implicit long-tail distribution of various class attributes, many of which are out of distribution and can lead to data imbalancedness issue. This motivates us to seek a principled approach of selecting a subset of unlabeled data from an external source that are relevant for learning better and diverse representations. In this work, we propose an open-world unlabeled data sampling strategy called Model-Aware K-center (MAK), which follows three simple principles: (1) tailness, which encourages sampling of examples from tail classes, by sorting the empirical contrastive loss expectation (ECLE) of samples over random data augmentations; (2) proximity, which rejects the out-of-distribution outliers that might distract training; and (3) diversity, which ensures diversity in the set of sampled examples. Empirically, using ImageNet-100-LT (without labels) as the target dataset and two ``noisy'' external data sources, we demonstrate that MAK can consistently improve both the overall representation quality and class balancedness of the learned features, as evaluated via linear classifier evaluation on full-shot and few-shot settings.

Method

pipeline

Environment

Requirements:

pytorch 1.7.1 
opencv-python
kmeans-pytorch 0.3
scikit-learn

Recommend installation cmds (linux)

conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch # change cuda version according to hardware
pip install opencv-python
conda install -c conda-forge matplotlib scikit-learn

Sampling

Prepare

change the access permissions

chmod +x  cmds/shell_scrips/*

Get pre-trained model on LT datasets

bash ./cmds/shell_scrips/imagenet-100-add-data.sh -g 2 -p 4866 -w 10 --seed 10 --additional_dataset None

Sampling on ImageNet 900

Inference

inference on sampling dataset (no Aug)

bash ./cmds/shell_scrips/imagenet-100-inference.sh -p 5555 --workers 10 --pretrain_seed 10 \
--epochs 1000 --batch_size 256 --inference_dataset imagenet-900 --inference_dataset_split ImageNet_900_train \
--inference_repeat_time 1 --inference_noAug True

inference on sampling dataset (no Aug)

bash ./cmds/shell_scrips/imagenet-100-inference.sh -p 5555 --workers 10 --pretrain_seed 10 \
--epochs 1000 --batch_size 256 --inference_dataset imagenet-100 --inference_dataset_split imageNet_100_LT_train \
--inference_repeat_time 1 --inference_noAug True

inference on sampling dataset (w/ Aug)

bash ./cmds/shell_scrips/imagenet-100-inference.sh -p 5555 --workers 10 --pretrain_seed 10 \
--epochs 1000 --batch_size 256 --inference_dataset imagenet-900 --inference_dataset_split ImageNet_900_train \
--inference_repeat_time 10

sampling 10K at Imagenet900

bash ./cmds/shell_scrips/sampling.sh --pretrain_seed 10

Citation

@inproceedings{
jiang2021improving,
title={Improving Contrastive Learning on Imbalanced Data via Open-World Sampling},
author={Jiang, Ziyu and Chen, Tianlong and Chen, Ting and Wang, Zhangyang},
booktitle={Advances in Neural Information Processing Systems 35},
year={2021}
}
Owner
VITA
Visual Informatics Group @ University of Texas at Austin
VITA
A Web API for automatic background removal using Deep Learning. App is made using Flask and deployed on Heroku.

Automatic_Background_Remover A Web API for automatic background removal using Deep Learning. App is made using Flask and deployed on Heroku. 👉 https:

Gaurav 16 Oct 29, 2022
一个多模态内容理解算法框架,其中包含数据处理、预训练模型、常见模型以及模型加速等模块。

Overview 架构设计 插件介绍 安装使用 框架简介 方便使用,支持多模态,多任务的统一训练框架 能力列表: bert + 分类任务 自定义任务训练(插件注册) 框架设计 框架采用分层的思想组织模型训练流程。 DATA 层负责读取用户数据,根据 field 管理数据。 Parser 层负责转换原

Tencent 265 Dec 22, 2022
Flaxformer: transformer architectures in JAX/Flax

Flaxformer is a transformer library for primarily NLP and multimodal research at Google.

Google 116 Jan 05, 2023
用强化学习DQN算法,训练AI模型来玩合成大西瓜游戏,提供Keras版本和PARL(paddle)版本

用强化学习玩合成大西瓜 代码地址:https://github.com/Sharpiless/play-daxigua-using-Reinforcement-Learning 用强化学习DQN算法,训练AI模型来玩合成大西瓜游戏,提供Keras版本、PARL(paddle)版本和pytorch版本

72 Dec 17, 2022
Unofficial Implementation of MLP-Mixer in TensorFlow

mlp-mixer-tf Unofficial Implementation of MLP-Mixer [abs, pdf] in TensorFlow. Note: This project may have some bugs in it. I'm still learning how to i

Rishabh Anand 24 Mar 23, 2022
Introducing neural networks to predict stock prices

IntroNeuralNetworks in Python: A Template Project IntroNeuralNetworks is a project that introduces neural networks and illustrates an example of how o

Vivek Palaniappan 637 Jan 04, 2023
Genetic Programming in Python, with a scikit-learn inspired API

Welcome to gplearn! gplearn implements Genetic Programming in Python, with a scikit-learn inspired and compatible API. While Genetic Programming (GP)

Trevor Stephens 1.3k Jan 03, 2023
Official implementation of Deep Burst Super-Resolution

Deep-Burst-SR Official implementation of Deep Burst Super-Resolution Publication: Deep Burst Super-Resolution. Goutam Bhat, Martin Danelljan, Luc Van

Goutam Bhat 113 Dec 19, 2022
This porject is intented to build the most accurate model for predicting the porbability of loan default

Estimating-Loan-Default-Probability IBA ML2 Mid-project / Kaggle Competition This porject is intented to build the most accurate model for predicting

Adil Gahramanov 1 Jan 24, 2022
This repository introduces a short project about Transfer Learning for Classification of MRI Images.

Transfer Learning for MRI Images Classification This repository introduces a short project made during my stay at Neuromatch Summer School 2021. This

Oscar Guarnizo 3 Nov 15, 2022
Auto White-Balance Correction for Mixed-Illuminant Scenes

Auto White-Balance Correction for Mixed-Illuminant Scenes Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown York University Video Reference code

Mahmoud Afifi 47 Nov 26, 2022
A quick recipe to learn all about Transformers

Transformers have accelerated the development of new techniques and models for natural language processing (NLP) tasks.

DAIR.AI 772 Dec 31, 2022
Code for "PVNet: Pixel-wise Voting Network for 6DoF Pose Estimation" CVPR 2019 oral

Good news! We release a clean version of PVNet: clean-pvnet, including how to train the PVNet on the custom dataset. Use PVNet with a detector. The tr

ZJU3DV 722 Dec 27, 2022
LieTransformer: Equivariant Self-Attention for Lie Groups

LieTransformer This repository contains the implementation of the LieTransformer used for experiments in the paper LieTransformer: Equivariant Self-At

OxCSML (Oxford Computational Statistics and Machine Learning) 50 Dec 28, 2022
An easy way to build PyTorch datasets. Modularly build datasets and automatically cache processed results

EasyDatas An easy way to build PyTorch datasets. Modularly build datasets and automatically cache processed results Installation pip install git+https

Ximing Yang 4 Dec 14, 2021
Simple implementation of Mobile-Former on Pytorch

Simple-implementation-of-Mobile-Former At present, only the model but no trained. There may be some bug in the code, and some details may be different

Acheung 103 Dec 31, 2022
improvement of CLIP features over the traditional resnet features on the visual question answering, image captioning, navigation and visual entailment tasks.

CLIP-ViL In our paper "How Much Can CLIP Benefit Vision-and-Language Tasks?", we show the improvement of CLIP features over the traditional resnet fea

310 Dec 28, 2022
A Pytorch implementation of the multi agent deep deterministic policy gradients (MADDPG) algorithm

Multi-Agent-Deep-Deterministic-Policy-Gradients A Pytorch implementation of the multi agent deep deterministic policy gradients(MADDPG) algorithm This

Phil Tabor 159 Dec 28, 2022
ML models implementation practice

Let's implement various ML algorithms with numpy/tf Vanilla Neural Network https://towardsdatascience.com/lets-code-a-neural-network-in-plain-numpy-ae

Jinsoo Heo 4 Jul 04, 2021
CUda Matrix Multiply library.

cumm CUda Matrix Multiply library. cumm is developed during learning of CUTLASS, which use too much c++ template and make code unmaintainable. So I de

49 Dec 27, 2022