Ensembling Off-the-shelf Models for GAN Training

Overview

Data-Efficient GANs with DiffAugment

project | paper | datasets | video | slides

Generated using only 100 images of Obama, grumpy cats, pandas, the Bridge of Sighs, the Medici Fountain, the Temple of Heaven, without pre-training.

[NEW!] PyTorch training with DiffAugment-stylegan2-pytorch is now available!

[NEW!] Our Colab tutorial is released!

[NEW!] FFHQ training is supported! See the DiffAugment-stylegan2 README.

[NEW!] Time to generate 100-shot interpolation videos with generate_gif.py!

[NEW!] Our DiffAugment-biggan-imagenet repo (for TPU training) is released!

[NEW!] Our DiffAugment-biggan-cifar PyTorch repo is released!

This repository contains our implementation of Differentiable Augmentation (DiffAugment) in both PyTorch and TensorFlow. It can be used to significantly improve the data efficiency for GAN training. We have provided DiffAugment-stylegan2 (TensorFlow) and DiffAugment-stylegan2-pytorch, DiffAugment-biggan-cifar (PyTorch) for GPU training, and DiffAugment-biggan-imagenet (TensorFlow) for TPU training.

Low-shot generation without pre-training. With DiffAugment, our model can generate high-fidelity images using only 100 Obama portraits, grumpy cats, or pandas from our collected 100-shot datasets, 160 cats or 389 dogs from the AnimalFace dataset at 256×256 resolution.

Unconditional generation results on CIFAR-10. StyleGAN2’s performance drastically degrades given less training data. With DiffAugment, we are able to roughly match its FID and outperform its Inception Score (IS) using only 20% training data.

Differentiable Augmentation for Data-Efficient GAN Training
Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
MIT, Tsinghua University, Adobe Research, CMU
arXiv

Overview

Overview of DiffAugment for updating D (left) and G (right). DiffAugment applies the augmentation T to both the real sample x and the generated output G(z). When we update G, gradients need to be back-propagated through T (iii), which requires T to be differentiable w.r.t. the input.

Training and Generation with 100 Images

To generate an interpolation video using our pre-trained models:

cd DiffAugment-stylegan2
python generate_gif.py -r mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl -o obama.gif

or to train a new model:

python run_low_shot.py --dataset=100-shot-obama --num-gpus=4

You may also try out 100-shot-grumpy_cat, 100-shot-panda, 100-shot-bridge_of_sighs, 100-shot-medici_fountain, 100-shot-temple_of_heaven, 100-shot-wuzhen, or the folder containing your own training images. Please refer to the DiffAugment-stylegan2 README for the dependencies and details.

[NEW!] PyTorch training is now available:

cd DiffAugment-stylegan2-pytorch
python train.py --outdir=training-runs --data=https://data-efficient-gans.mit.edu/datasets/100-shot-obama.zip --gpus=1

DiffAugment for StyleGAN2

To run StyleGAN2 + DiffAugment for unconditional generation on the 100-shot datasets, CIFAR, FFHQ, or LSUN, please refer to the DiffAugment-stylegan2 README or DiffAugment-stylegan2-pytorch for the PyTorch version.

DiffAugment for BigGAN

Please refer to the DiffAugment-biggan-cifar README to run BigGAN + DiffAugment for conditional generation on CIFAR (using GPUs), and the DiffAugment-biggan-imagenet README to run on ImageNet (using TPUs).

Using DiffAugment for Your Own Training

To help you use DiffAugment in your own codebase, we provide portable DiffAugment operations of both TensorFlow and PyTorch versions in DiffAugment_tf.py and DiffAugment_pytorch.py. Generally, DiffAugment can be easily adopted in any model by substituting every D(x) with D(T(x)), where x can be real images or fake images, D is the discriminator, and T is the DiffAugment operation. For example,

from DiffAugment_pytorch import DiffAugment
# from DiffAugment_tf import DiffAugment
policy = 'color,translation,cutout' # If your dataset is as small as ours (e.g.,
# hundreds of images), we recommend using the strongest Color + Translation + Cutout.
# For large datasets, try using a subset of transformations in ['color', 'translation', 'cutout'].
# Welcome to discover more DiffAugment transformations!

...
# Training loop: update D
reals = sample_real_images() # a batch of real images
z = sample_latent_vectors()
fakes = Generator(z) # a batch of fake images
real_scores = Discriminator(DiffAugment(reals, policy=policy))
fake_scores = Discriminator(DiffAugment(fakes, policy=policy))
# Calculating D's loss based on real_scores and fake_scores...
...

...
# Training loop: update G
z = sample_latent_vectors()
fakes = Generator(z) # a batch of fake images
fake_scores = Discriminator(DiffAugment(fakes, policy=policy))
# Calculating G's loss based on fake_scores...
...

We have implemented Color, Translation, and Cutout DiffAugment as visualized below:

Citation

If you find this code helpful, please cite our paper:

@inproceedings{zhao2020diffaugment,
  title={Differentiable Augmentation for Data-Efficient GAN Training},
  author={Zhao, Shengyu and Liu, Zhijian and Lin, Ji and Zhu, Jun-Yan and Han, Song},
  booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
  year={2020}
}

Acknowledgements

We thank NSF Career Award #1943349, MIT-IBM Watson AI Lab, Google, Adobe, and Sony for supporting this research. Research supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC). We thank William S. Peebles and Yijun Li for helpful comments.

Owner
MIT HAN Lab
Accelerating Deep Learning Computing
MIT HAN Lab
The world's largest toxicity dataset.

The Toxicity Dataset by Surge AI Saving the internet is fun. Combing through thousands of online comments to build a toxicity dataset isn't. That's wh

Surge AI 134 Dec 19, 2022
Supervised Sliding Window Smoothing Loss Function Based on MS-TCN for Video Segmentation

SSWS-loss_function_based_on_MS-TCN Supervised Sliding Window Smoothing Loss Function Based on MS-TCN for Video Segmentation Supervised Sliding Window

3 Aug 03, 2022
FindFunc is an IDA PRO plugin to find code functions that contain a certain assembly or byte pattern, reference a certain name or string, or conform to various other constraints.

FindFunc: Advanced Filtering/Finding of Functions in IDA Pro FindFunc is an IDA Pro plugin to find code functions that contain a certain assembly or b

213 Dec 17, 2022
Tools for manipulating UVs in the Blender viewport.

UV Tool Suite for Blender A set of tools to make editing UVs easier in Blender. These tools can be accessed wither through the Kitfox - UV panel on th

35 Oct 29, 2022
Deep Learning agent of Starcraft2, similar to AlphaStar of DeepMind except size of network.

Introduction This repository is for Deep Learning agent of Starcraft2. It is very similar to AlphaStar of DeepMind except size of network. I only test

Dohyeong Kim 136 Jan 04, 2023
[ICCV 2021] Deep Hough Voting for Robust Global Registration

Deep Hough Voting for Robust Global Registration, ICCV, 2021 Project Page | Paper | Video Deep Hough Voting for Robust Global Registration Junha Lee1,

57 Nov 28, 2022
Ian Covert 130 Jan 01, 2023
Materials for upcoming beginner-friendly PyTorch course (work in progress).

Learn PyTorch for Deep Learning (work in progress) I'd like to learn PyTorch. So I'm going to use this repo to: Add what I've learned. Teach others in

Daniel Bourke 2.3k Dec 29, 2022
pix2pix in tensorflow.js

pix2pix in tensorflow.js This repo is moved to https://github.com/yining1023/pix2pix_tensorflowjs_lite See a live demo here: https://yining1023.github

Yining Shi 47 Oct 04, 2022
DSAC* for Visual Camera Re-Localization (RGB or RGB-D)

DSAC* for Visual Camera Re-Localization (RGB or RGB-D) Introduction Installation Data Structure Supported Datasets 7Scenes 12Scenes Cambridge Landmark

Visual Learning Lab 143 Dec 22, 2022
The implementation for "Comprehensive Knowledge Distillation with Causal Intervention".

Comprehensive Knowledge Distillation with Causal Intervention This repository is a PyTorch implementation of "Comprehensive Knowledge Distillation wit

Xiang Deng 10 Nov 03, 2022
[CVPR 2019 Oral] Multi-Channel Attention Selection GAN with Cascaded Semantic Guidance for Cross-View Image Translation

SelectionGAN for Guided Image-to-Image Translation CVPR Paper | Extended Paper | Guided-I2I-Translation-Papers Citation If you use this code for your

Hao Tang 424 Dec 02, 2022
Representing Long-Range Context for Graph Neural Networks with Global Attention

Graph Augmentation Graph augmentation/self-supervision/etc. Algorithms gcn gcn+virtual node gin gin+virtual node PNA GraphTrans Augmentation methods N

UC Berkeley RISE 67 Dec 30, 2022
Code for Discriminative Sounding Objects Localization (NeurIPS 2020)

Discriminative Sounding Objects Localization Code for our NeurIPS 2020 paper Discriminative Sounding Objects Localization via Self-supervised Audiovis

51 Dec 11, 2022
This repo contains the implementation of the algorithm proposed in Off-Belief Learning, ICML 2021.

Off-Belief Learning Introduction This repo contains the implementation of the algorithm proposed in Off-Belief Learning, ICML 2021. Environment Setup

Facebook Research 32 Jan 05, 2023
Python scripts for performing 3D human pose estimation using the Mobile Human Pose model in ONNX.

Python scripts for performing 3D human pose estimation using the Mobile Human Pose model in ONNX.

Ibai Gorordo 99 Dec 31, 2022
Adabelief-Optimizer - Repository for NeurIPS 2020 Spotlight "AdaBelief Optimizer: Adapting stepsizes by the belief in observed gradients"

AdaBelief Optimizer NeurIPS 2020 Spotlight, trains fast as Adam, generalizes well as SGD, and is stable to train GANs. Release of package We have rele

Juntang Zhuang 998 Dec 29, 2022
A simple baseline for the 2022 IEEE GRSS Data Fusion Contest (DFC2022)

DFC2022 Baseline A simple baseline for the 2022 IEEE GRSS Data Fusion Contest (DFC2022) This repository uses TorchGeo, PyTorch Lightning, and Segmenta

isaac 24 Nov 28, 2022
A scientific and useful toolbox, which contains practical and effective long-tail related tricks with extensive experimental results

Bag of tricks for long-tailed visual recognition with deep convolutional neural networks This repository is the official PyTorch implementation of AAA

Yong-Shun Zhang 181 Dec 28, 2022
Generic template to bootstrap your PyTorch project with PyTorch Lightning, Hydra, W&B, and DVC.

NN Template Generic template to bootstrap your PyTorch project. Click on Use this Template and avoid writing boilerplate code for: PyTorch Lightning,

Luca Moschella 520 Dec 30, 2022