ICLR 2021 i-Mix: A Domain-Agnostic Strategy for Contrastive Representation Learning

Related tags

Deep Learningimix
Overview

Introduction

PyTorch code for the ICLR 2021 paper [i-Mix: A Domain-Agnostic Strategy for Contrastive Representation Learning].

@inproceedings{lee2021imix,
  title={i-Mix: A Domain-Agnostic Strategy for Contrastive Representation Learning},
  author={Lee, Kibok and Zhu, Yian and Sohn, Kihyuk and Li, Chun-Liang and Shin, Jinwoo and Lee, Honglak},
  booktitle={ICLR},
  year={2021}
}

Dependencies

  • python 3.7.4
  • numpy 1.17.2
  • pytorch 1.4.0
  • torchvision 0.5.0
  • cudatoolkit 10.1
  • librosa 0.8.0 for speech_commands
  • PIL 6.2.0 for GaussianBlur

Data

  • CIFAR-10/100 will automatically be downloaded.
  • For ImageNet, please refer to the [PyTorch ImageNet example]. The folder structure should be like data/imagenet/train/n01440764/
  • For speech commands, run bash speech_commands/download_speech_commands_dataset.sh.
  • For tabular datasets, download [covtype.data.gz] and [HIGGS.csv.gz], and place them in data/. They are processed when first loaded.

Running scripts

Please refer to [run.sh].

Plug-in example

For those who want to apply our method in their own code, we provide a minimal example based on [MoCo]:

# mixup: somewhere in main_moco.py
def mixup(input, alpha):
    beta = torch.distributions.beta.Beta(alpha, alpha)
    randind = torch.randperm(input.shape[0], device=input.device)
    lam = beta.sample([input.shape[0]]).to(device=input.device)
    lam = torch.max(lam, 1. - lam)
    lam_expanded = lam.view([-1] + [1]*(input.dim()-1))
    output = lam_expanded * input + (1. - lam_expanded) * input[randind]
    return output, randind, lam

# cutmix: somewhere in main_moco.py
def cutmix(input, alpha):
    beta = torch.distributions.beta.Beta(alpha, alpha)
    randind = torch.randperm(input.shape[0], device=input.device)
    lam = beta.sample().to(device=input.device)
    lam = torch.max(lam, 1. - lam)
    (bbx1, bby1, bbx2, bby2), lam = rand_bbox(input.shape[-2:], lam)
    output = input.clone()
    output[..., bbx1:bbx2, bby1:bby2] = output[randind][..., bbx1:bbx2, bby1:bby2]
    return output, randind, lam

def rand_bbox(size, lam):
    W, H = size
    cut_rat = (1. - lam).sqrt()
    cut_w = (W * cut_rat).to(torch.long)
    cut_h = (H * cut_rat).to(torch.long)

    cx = torch.zeros_like(cut_w, dtype=cut_w.dtype).random_(0, W)
    cy = torch.zeros_like(cut_h, dtype=cut_h.dtype).random_(0, H)

    bbx1 = (cx - cut_w // 2).clamp(0, W)
    bby1 = (cy - cut_h // 2).clamp(0, H)
    bbx2 = (cx + cut_w // 2).clamp(0, W)
    bby2 = (cy + cut_h // 2).clamp(0, H)

    new_lam = 1. - (bbx2 - bbx1).to(lam.dtype) * (bby2 - bby1).to(lam.dtype) / (W * H)

    return (bbx1, bby1, bbx2, bby2), new_lam

# https://github.com/facebookresearch/moco/blob/master/main_moco.py#L193
criterion = nn.CrossEntropyLoss(reduction='none').cuda(args.gpu)

# https://github.com/facebookresearch/moco/blob/master/main_moco.py#L302-L303
images[0], target_aux, lam = mixup(images[0], alpha=1.)
# images[0], target_aux, lam = cutmix(images[0], alpha=1.)
target = torch.arange(images[0].shape[0], dtype=torch.long).cuda()
output, _ = model(im_q=images[0], im_k=images[1])
loss = lam * criterion(output, target) + (1. - lam) * criterion(output, target_aux)

# https://github.com/facebookresearch/moco/blob/master/moco/builder.py#L142-L149
contrast = torch.cat([k, self.queue.clone().detach().t()], dim=0)
logits = torch.mm(q, contrast.t())

Note

Owner
Kibok Lee
Kibok Lee
Checkout some cool self-projects you can try your hands on to curb your boredom this December!

SoC-Winter Checkout some cool self-projects you can try your hands on to curb your boredom this December! These are short projects that you can do you

Web and Coding Club, IIT Bombay 29 Nov 08, 2022
Marine debris detection with commercial satellite imagery and deep learning.

Marine debris detection with commercial satellite imagery and deep learning. Floating marine debris is a global pollution problem which threatens mari

Inter Agency Implementation and Advanced Concepts 56 Dec 16, 2022
CenterNet:Objects as Points目标检测模型在Pytorch当中的实现

CenterNet:Objects as Points目标检测模型在Pytorch当中的实现

Bubbliiiing 267 Dec 29, 2022
A Python library for common tasks on 3D point clouds

Point Cloud Utils (pcu) - A Python library for common tasks on 3D point clouds Point Cloud Utils (pcu) is a utility library providing the following fu

Francis Williams 622 Dec 27, 2022
Machine Learning with JAX Tutorials

The purpose of this repo is to make it easy to get started with JAX. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I fou

Aleksa Gordić 372 Dec 28, 2022
Unit-Convertor - Unit Convertor Built With Python

Python Unit Converter This project can convert Weigth,length and ... units for y

Mahdis Esmaeelian 1 May 31, 2022
Semantic-aware Grad-GAN for Virtual-to-Real Urban Scene Adaption

SG-GAN TensorFlow implementation of SG-GAN. Prerequisites TensorFlow (implemented in v1.3) numpy scipy pillow Getting Started Train Prepare dataset. W

lplcor 61 Jun 07, 2022
Implementation for HFGI: High-Fidelity GAN Inversion for Image Attribute Editing

HFGI: High-Fidelity GAN Inversion for Image Attribute Editing High-Fidelity GAN Inversion for Image Attribute Editing Update: We released the inferenc

Tengfei Wang 371 Dec 30, 2022
This is the reference implementation for "Coresets via Bilevel Optimization for Continual Learning and Streaming"

Coresets via Bilevel Optimization This is the reference implementation for "Coresets via Bilevel Optimization for Continual Learning and Streaming" ht

Zalán Borsos 51 Dec 30, 2022
Decision Transformer: A brand new Offline RL Pattern

DecisionTransformer_StepbyStep Intro Decision Transformer: A brand new Offline RL Pattern. 这是关于NeurIPS 2021 热门论文Decision Transformer的复现。 👍 原文地址: Deci

Irving 14 Nov 22, 2022
patchmatch和patchmatchstereo算法的python实现

patchmatch patchmatch以及patchmatchstereo算法的python版实现 patchmatch参考 github patchmatchstereo参考李迎松博士的c++版代码 由于patchmatchstereo没有做任何优化,并且是python的代码,主要是方便解析算

Sanders Bao 11 Dec 02, 2022
3D dataset of humans Manipulating Objects in-the-Wild (MOW)

MOW dataset [Website] This repository maintains our 3D dataset of humans Manipulating Objects in-the-Wild (MOW). The dataset contains 512 images in th

Zhe Cao 28 Nov 06, 2022
A `Neural = Symbolic` framework for sound and complete weighted real-value logic

Logical Neural Networks LNNs are a novel Neuro = symbolic framework designed to seamlessly provide key properties of both neural nets (learning) and s

International Business Machines 138 Dec 19, 2022
python 93% acc. CNN Dogs Vs Cats ( Pytorch )

English | 简体中文(测试中...敬请期待) Cnn-Classification-Dog-Vs-Cat 猫狗辨别 (pytorch版本) CNN Resnet18 的猫狗分类器,基于ResNet及其变体网路系列,对于一般的图像识别任务表现优异,模型精准度高达93%(小型样本)。 项目制作于

apple ye 1 May 22, 2022
A state of the art of new lightweight YOLO model implemented by TensorFlow 2.

CSL-YOLO: A New Lightweight Object Detection System for Edge Computing This project provides a SOTA level lightweight YOLO called "Cross-Stage Lightwe

Miles Zhang 54 Dec 21, 2022
Code for a seq2seq architecture with Bahdanau attention designed to map stereotactic EEG data from human brains to spectrograms, using the PyTorch Lightning.

stereoEEG2speech We provide code for a seq2seq architecture with Bahdanau attention designed to map stereotactic EEG data from human brains to spectro

15 Nov 11, 2022
Automatic deep learning for image classification.

AutoDL AutoDL automates machine learning tasks enabling you to easily achieve strong predictive performance in your applications. With just a few line

wenqi 2 Oct 12, 2022
Fast SHAP value computation for interpreting tree-based models

FastTreeSHAP FastTreeSHAP package is built based on the paper Fast TreeSHAP: Accelerating SHAP Value Computation for Trees published in NeurIPS 2021 X

LinkedIn 369 Jan 04, 2023
TRACER: Extreme Attention Guided Salient Object Tracing Network implementation in PyTorch

TRACER: Extreme Attention Guided Salient Object Tracing Network This paper was accepted at AAAI 2022 SA poster session. Datasets All datasets are avai

Karel 118 Dec 29, 2022
A Review of Deep Learning Techniques for Markerless Human Motion on Synthetic Datasets

HOW TO USE THIS PROJECT A Review of Deep Learning Techniques for Markerless Human Motion on Synthetic Datasets Based on DeepLabCut toolbox, we run wit

1 Jan 10, 2022