Code for the paper "Training GANs with Stronger Augmentations via Contrastive Discriminator" (ICLR 2021)


Training GANs with Stronger Augmentations via Contrastive Discriminator (ICLR 2021)

This repository contains the code for reproducing the paper: Training GANs with Stronger Augmentations via Contrastive Discriminator by Jongheon Jeong and Jinwoo Shin.

TL;DR: We propose a novel GAN discriminator and show that contrastive representation learning (e.g., SimCLR) and GANs can benefit each other when they are trained jointly.

Demo

Qualitative comparison of unconditional generations from GANs on high-resolution, yet limited-sized datasets: AFHQ-Dog (4739 samples), AFHQ-Cat (5153 samples) and AFHQ-Wild (4738 samples).

Overview


An overview of Contrastive Discriminator (ContraD). The representation of ContraD is not learned from the discriminator loss (L_dis) but from two contrastive losses: L+_con for real samples and L-_con for fake samples. The actual discriminator that minimizes L_dis is simply a 2-layer MLP head on top of the learned contrastive representation.
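As a rough illustration of the real-sample term L+_con, the sketch below implements the standard SimCLR (NT-Xent) loss in plain Python, assuming L+_con takes that form over two augmented views of each real image. The function name `nt_xent` and the default `temperature=0.1` are illustrative choices, not taken from this repository; L-_con and the 2-layer MLP discriminator head are not shown. In actual training the embeddings would be tensors produced by the ContraD encoder and projection head; plain lists keep the sketch dependency-free.

```python
import math
from typing import List

Vector = List[float]


def _normalize(v: Vector) -> Vector:
    """Scale v to unit L2 norm (a zero vector is returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def _dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def nt_xent(view1: List[Vector], view2: List[Vector], temperature: float = 0.1) -> float:
    """SimCLR NT-Xent loss for N positive pairs (view1[i], view2[i]).

    Each of the 2N embeddings is contrasted against the other 2N - 1;
    its positive is the other augmented view of the same image.
    """
    n = len(view1)
    z = [_normalize(v) for v in list(view1) + list(view2)]
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)  # index of the paired view of the same image
        logits = [_dot(z[i], z[k]) / temperature for k in range(2 * n) if k != i]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_denom = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_denom - _dot(z[i], z[pos]) / temperature
    return total / (2 * n)
```

Pairs whose two views agree (and differ from other images) yield a much smaller loss than pairs whose views are mismatched, which is the signal that shapes the ContraD representation.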

Dependencies

Currently, the following environment has been confirmed to run the code:

  • python >= 3.6
  • pytorch >= 1.6.0 (see https://pytorch.org/ for detailed installation instructions)
  • tensorflow-gpu == 1.14.0 to run test_tf_inception.py for FID/IS evaluations
  • Other requirements can be found in environment.yml (for conda users) or environment_pip.txt (for pip users)

# Install dependencies via conda.
# The file also includes `pytorch`, `tensorflow-gpu=1.14`, and `cudatoolkit=10.1`.
# You may have to set the `cudatoolkit` version compatible with your system.
# This command creates a new conda environment named `contrad`.
conda env create -f environment.yml

# Install dependencies via pip.
# This assumes `pytorch` and `tensorflow-gpu` are already installed in the current environment.
pip install -r environment_pip.txt
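To check that an environment matches the confirmed versions listed above, a small script such as the following can be run. This is a minimal sketch: `parse_version`, `at_least` and `check_environment` are hypothetical helpers, not part of this repository, and only Python and PyTorch are checked (the `tensorflow-gpu` pin matters only for test_tf_inception.py).

```python
import sys
from importlib import import_module


def parse_version(v: str) -> tuple:
    """Turn a string like '1.6.0+cu101' into (1, 6, 0); non-numeric suffixes are dropped."""
    parts = []
    for piece in v.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def at_least(installed: str, required: str) -> bool:
    """Compare versions after padding the shorter tuple with zeros."""
    a, b = parse_version(installed), parse_version(required)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_environment() -> dict:
    report = {"python>=3.6": sys.version_info >= (3, 6)}
    try:
        torch = import_module("torch")
        report["pytorch>=1.6.0"] = at_least(str(torch.__version__), "1.6.0")
    except ImportError:
        report["pytorch>=1.6.0"] = None  # PyTorch is not installed
    return report


if __name__ == "__main__":
    print(check_environment())
```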

Preparing datasets

By default, the code assumes that all the datasets are placed under data/. You can change this path by setting the $DATA_DIR environment variable.
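Inside a custom script, the same convention can be reproduced in a few lines of Python. This is a minimal sketch assuming the default is the relative folder data/; `resolve_data_dir` is a hypothetical name, not a function from this repository.

```python
import os
from pathlib import Path


def resolve_data_dir(default: str = "data") -> Path:
    """Return $DATA_DIR if it is set and non-empty, otherwise the default data/ folder."""
    return Path(os.environ.get("DATA_DIR") or default)
```

From the shell, prefixing a command with the variable (e.g., `DATA_DIR=/path/to/datasets python train_gan.py ...`) points that single run at another location.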

CIFAR-10/100 can be automatically downloaded by running any of the provided training scripts.

CelebA-HQ-128:

  1. Download the CelebA-HQ dataset and extract it under $DATA_DIR.
  2. Run third_party/preprocess_celeba_hq.py to resize the 1024x1024 images in $DATA_DIR/CelebAMask-HQ/CelebA-HQ-img to 128x128 and split them into train/test sets:
    python third_party/preprocess_celeba_hq.py


AFHQ datasets:

  1. Download the AFHQ dataset and extract it under $DATA_DIR.
  2. Reorganize the directories in $DATA_DIR/afhq so that they are compatible with torchvision.datasets.ImageFolder. Refer to the file structure below for details.

The structure of $DATA_DIR should look roughly as follows:

$DATA_DIR
├── cifar-10-batches-py   # CIFAR-10
├── cifar-100-python      # CIFAR-100
├── CelebAMask-HQ         # CelebA-HQ-128
│   ├── CelebA-128-split  # Resized to 128x128 from `CelebA-HQ-img`
│   │   ├── train
│   │   │   └── images
│   │   │       ├── 0.jpg
│   │   │       └── ...
│   │   └── test
│   ├── CelebA-HQ-img     # Original 1024x1024 images
│   ├── CelebA-HQ-to-CelebA-mapping.txt
│   └── README.txt
└── afhq                  # AFHQ datasets
    ├── cat
    │   ├── train
    │   │   └── images
    │   │       ├── flickr_cat_00xxxx.jpg
    │   │       └── ...
    │   └── val
    ├── dog
    └── wild
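The AFHQ reorganization step can be sketched as below. It assumes, as an unverified guess about your download, that the original AFHQ release unpacks as $DATA_DIR/afhq/{train,val}/{cat,dog,wild}/*.jpg, and it moves every image into the $DATA_DIR/afhq/<class>/<split>/images/ layout shown above so that each <class>/<split> folder can be read by torchvision.datasets.ImageFolder. `reorganize_afhq` is a hypothetical helper, not a script shipped with this repository.

```python
import shutil
from pathlib import Path

CLASSES = ("cat", "dog", "wild")
SPLITS = ("train", "val")


def reorganize_afhq(afhq_root: Path) -> None:
    """Move afhq/<split>/<class>/* into afhq/<class>/<split>/images/ (assumed source layout)."""
    for split in SPLITS:
        for cls in CLASSES:
            src = afhq_root / split / cls
            if not src.is_dir():
                continue  # already reorganized, or this class is missing
            dst = afhq_root / cls / split / "images"
            dst.mkdir(parents=True, exist_ok=True)
            for img in list(src.iterdir()):  # snapshot the listing before moving files
                if img.is_file():
                    shutil.move(str(img), str(dst / img.name))
            if not any(src.iterdir()):
                src.rmdir()  # drop the now-empty source folder
    for split in SPLITS:
        leftover = afhq_root / split
        if leftover.is_dir() and not any(leftover.iterdir()):
            leftover.rmdir()
```

For example, `reorganize_afhq(Path("data") / "afhq")` would reorganize a dataset placed under the default data/ directory.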

Scripts

Training Scripts

We provide training scripts to reproduce the results, train_*.py, listed below:

| File | Description |
| --- | --- |
| train_gan.py | Train a GAN model other than StyleGAN2. DistributedDataParallel supported. |
| train_stylegan2.py | Train a StyleGAN2 model. It additionally implements the details of StyleGAN2 training, e.g., R1 regularization and EMA. DataParallel supported. |
| train_stylegan2_contraD.py | Training script optimized for StyleGAN2 + ContraD. It runs faster, especially on high-resolution datasets, e.g., 512x512 AFHQ. DataParallel supported. |
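For reference, the R1 penalty of Mescheder et al. (2018), as used in StyleGAN2, penalizes the squared gradient norm of the discriminator on real data only. We assume, without having checked the code, that the `--lbd_r1` flag in the commands below plays the role of the weight γ (or a closely related coefficient), and that `--no_lazy` disables StyleGAN2's lazy regularization, under which the official StyleGAN2 applies this term only every 16 minibatches.

```latex
R_1(\psi) = \frac{\gamma}{2}\,
  \mathbb{E}_{x \sim p_{\mathrm{data}}}\!\left[ \left\lVert \nabla_x D_\psi(x) \right\rVert_2^2 \right]
```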

The examples below show how to run each script to train GANs with ContraD. Further instructions for reproducing our experiments, e.g., other baselines, can be found in EXPERIMENTS.md. Set CUDA_VISIBLE_DEVICES to choose the GPU(s) to run on.

# SNDCGAN + ContraD on CIFAR-10
CUDA_VISIBLE_DEVICES=0 python train_gan.py configs/gan/cifar10/c10_b512.gin sndcgan \
--mode=contrad --aug=simclr --use_warmup

# StyleGAN2 + ContraD on CIFAR-10 - it is OK to simply use `train_stylegan2.py` even with ContraD
python train_stylegan2.py configs/gan/stylegan2/c10_style64.gin stylegan2 \
--mode=contrad --aug=simclr --lbd_r1=0.1 --no_lazy --halflife_k=1000 --use_warmup

# Nevertheless, StyleGAN2 + ContraD can be trained more efficiently with `train_stylegan2_contraD.py` 
python train_stylegan2_contraD.py configs/gan/stylegan2/afhq_dog_style64.gin stylegan2_512 \
--mode=contrad --aug=simclr_hq --lbd_r1=0.5 --halflife_k=20 --use_warmup \
--evaluate_every=5000 --n_eval_avg=1 --no_gif 
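The `--halflife_k` flag presumably controls the EMA of generator weights mentioned for train_stylegan2.py. The sketch below assumes it follows the StyleGAN2 convention of an EMA half-life measured in thousands of training images; we have not verified this unit against the scripts, and `ema_beta`, `ema_update` and the batch size of 64 are hypothetical illustrations, with weights modeled as plain floats.

```python
from typing import Dict


def ema_beta(batch_size: int, halflife_k: float) -> float:
    """Per-step decay so the old weights' share halves every halflife_k * 1000 images (assumed unit)."""
    return 0.5 ** (batch_size / (halflife_k * 1000.0))


def ema_update(ema: Dict[str, float], current: Dict[str, float], beta: float) -> Dict[str, float]:
    """ema <- beta * ema + (1 - beta) * current, applied to every named weight."""
    return {name: beta * ema[name] + (1.0 - beta) * current[name] for name in current}


# Usage: with 64 images per step and --halflife_k=20, beta is slightly below 1.
beta = ema_beta(batch_size=64, halflife_k=20)
smoothed = ema_update({"w": 1.0}, {"w": 0.0}, beta)
```

Expressing the decay as a half-life in images, rather than a fixed per-step beta, keeps the averaging window the same when the batch size changes.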

Testing Scripts

  • The script test_gan_sample.py generates and saves random samples from a pre-trained generator model into *.jpg files. For example,

    CUDA_VISIBLE_DEVICES=0 python test_gan_sample.py PATH/TO/G.pt sndcgan --n_samples=10000
    

    will load the generator stored at PATH/TO/G.pt, generate n_samples=10000 samples from it, and save them under PATH/TO/samples_*/.

  • The script test_gan_sample_cddls.py additionally takes the discriminator and a linear evaluation head obtained from test_lineval.py to perform class-conditional discriminator-driven latent sampling (cDDLS). For example,

    CUDA_VISIBLE_DEVICES=0 python test_gan_sample_cddls.py LOGDIR PATH/TO/LINEAR.pth.tar sndcgan
    

    will load G and D stored in LOGDIR, the linear head stored at PATH/TO/LINEAR.pth.tar, and save the generated samples from cDDLS under LOGDIR/samples_cDDLS_*/.

  • The script test_lineval.py performs linear evaluation of a pre-trained discriminator model stored at model_path (PATH/TO/D.pt in the example below):

    CUDA_VISIBLE_DEVICES=0 python test_lineval.py PATH/TO/D.pt sndcgan
    
  • The script test_tf_inception.py computes the Fréchet Inception distance (FID) and Inception score (IS) with a TensorFlow backend, using the original FID code available at https://github.com/bioinf-jku/TTUR. tensorflow-gpu <= 1.14.0 is required to run this script. It takes a directory of generated samples (e.g., produced by test_gan_sample.py) and an .npz file of pre-computed statistics:

    python test_tf_inception.py PATH/TO/GENERATED/IMAGES/ PATH/TO/STATS.npz --n_imgs=10000 --gpu=0 --verbose
    

    A pre-computed statistics file for each dataset can either be found at http://bioinf.jku.at/research/ttur/ or computed manually; see third_party/tf/examples for sample scripts.
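The IS reported by test_tf_inception.py is computed from Inception-network class probabilities. To illustrate only the formula, the sketch below evaluates IS = exp(E_x[KL(p(y|x) || p(y))]) in plain Python over a list of per-sample probability vectors. `inception_score` is a hypothetical name; reference implementations also run the Inception network and typically average the score over several splits of the samples, both of which are omitted here.

```python
import math
from typing import List


def inception_score(probs: List[List[float]]) -> float:
    """IS = exp( mean over samples x of KL(p(y|x) || p(y)) ).

    probs[i] is the class distribution p(y|x_i) predicted for sample i;
    the marginal p(y) is estimated as the average of these rows.
    """
    n = len(probs)
    num_classes = len(probs[0])
    marginal = [sum(row[c] for row in probs) / n for c in range(num_classes)]
    mean_kl = 0.0
    for row in probs:
        # Terms with p(y|x) = 0 contribute nothing to the KL divergence.
        mean_kl += sum(p * math.log(p / marginal[c]) for c, p in enumerate(row) if p > 0) / n
    return math.exp(mean_kl)
```

If every sample receives the same prediction the score is 1, while confident predictions spread evenly over k classes give k, which is why IS rewards both sample quality and class diversity.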

Citation

@inproceedings{jeong2021contrad,
  title={Training {GAN}s with Stronger Augmentations via Contrastive Discriminator},
  author={Jongheon Jeong and Jinwoo Shin},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=eo6U4CAwVmg}
}