How to Train a GAN? Tips and tricks to make GANs work

Related tags

Deep Learningganhacks
Overview

(this list is no longer maintained, and I am not sure how relevant it is in 2020)

How to Train a GAN? Tips and tricks to make GANs work

While research in Generative Adversarial Networks (GANs) continues to improve the fundamental stability of these models, we use a bunch of tricks to train them and make them stable day to day.

Here are a summary of some of the tricks.

Here's a link to the authors of this document

If you find a trick that is particularly useful in practice, please open a Pull Request to add it to the document. If we find it to be reasonable and verified, we will merge it in.

1. Normalize the inputs

  • normalize the images between -1 and 1
  • Tanh as the last layer of the generator output

2: A modified loss function

In GAN papers, the loss function to optimize G is min (log 1-D), but in practice folks practically use max log D

  • because the first formulation has vanishing gradients early on
  • Goodfellow et. al (2014)

In practice, works well:

  • Flip labels when training generator: real = fake, fake = real

3: Use a spherical Z

  • Dont sample from a Uniform distribution

cube

  • Sample from a gaussian distribution

sphere

4: BatchNorm

  • Construct different mini-batches for real and fake, i.e. each mini-batch needs to contain only all real images or all generated images.
  • when batchnorm is not an option use instance normalization (for each sample, subtract mean and divide by standard deviation).

batchmix

5: Avoid Sparse Gradients: ReLU, MaxPool

  • the stability of the GAN game suffers if you have sparse gradients
  • LeakyReLU = good (in both G and D)
  • For Downsampling, use: Average Pooling, Conv2d + stride
  • For Upsampling, use: PixelShuffle, ConvTranspose2d + stride

6: Use Soft and Noisy Labels

  • Label Smoothing, i.e. if you have two target labels: Real=1 and Fake=0, then for each incoming sample, if it is real, then replace the label with a random number between 0.7 and 1.2, and if it is a fake sample, replace it with 0.0 and 0.3 (for example).
    • Salimans et. al. 2016
  • make the labels the noisy for the discriminator: occasionally flip the labels when training the discriminator

7: DCGAN / Hybrid Models

  • Use DCGAN when you can. It works!
  • if you cant use DCGANs and no model is stable, use a hybrid model : KL + GAN or VAE + GAN

8: Use stability tricks from RL

  • Experience Replay
    • Keep a replay buffer of past generations and occassionally show them
    • Keep checkpoints from the past of G and D and occassionaly swap them out for a few iterations
  • All stability tricks that work for deep deterministic policy gradients
  • See Pfau & Vinyals (2016)

9: Use the ADAM Optimizer

  • optim.Adam rules!
    • See Radford et. al. 2015
  • Use SGD for discriminator and ADAM for generator

10: Track failures early

  • D loss goes to 0: failure mode
  • check norms of gradients: if they are over 100 things are screwing up
  • when things are working, D loss has low variance and goes down over time vs having huge variance and spiking
  • if loss of generator steadily decreases, then it's fooling D with garbage (says martin)

11: Dont balance loss via statistics (unless you have a good reason to)

  • Dont try to find a (number of G / number of D) schedule to uncollapse training
  • It's hard and we've all tried it.
  • If you do try it, have a principled approach to it, rather than intuition

For example

while lossD > A:
  train D
while lossG > B:
  train G

12: If you have labels, use them

  • if you have labels available, training the discriminator to also classify the samples: auxillary GANs

13: Add noise to inputs, decay over time

14: [notsure] Train discriminator more (sometimes)

  • especially when you have noise
  • hard to find a schedule of number of D iterations vs G iterations

15: [notsure] Batch Discrimination

  • Mixed results

16: Discrete variables in Conditional GANs

  • Use an Embedding layer
  • Add as additional channels to images
  • Keep embedding dimensionality low and upsample to match image channel size

17: Use Dropouts in G in both train and test phase

Authors

  • Soumith Chintala
  • Emily Denton
  • Martin Arjovsky
  • Michael Mathieu
Owner
Soumith Chintala
/\︿╱\ _________________________________ \0_ 0 /╱\╱____________________________ \▁︹_/
Soumith Chintala
Abstractive opinion summarization system (SelSum) and the largest dataset of Amazon product summaries (AmaSum). EMNLP 2021 conference paper.

Learning Opinion Summarizers by Selecting Informative Reviews This repository contains the codebase and the dataset for the corresponding EMNLP 2021

Arthur Bražinskas 39 Jan 01, 2023
An SE(3)-invariant autoencoder for generating the periodic structure of materials

Crystal Diffusion Variational AutoEncoder This software implementes Crystal Diffusion Variational AutoEncoder (CDVAE), which generates the periodic st

Tian Xie 94 Dec 10, 2022
Sign Language is detected in realtime using video sequences. Our approach involves MediaPipe Holistic for keypoints extraction and LSTM Model for prediction.

RealTime Sign Language Detection using Action Recognition Approach Real-Time Sign Language is commonly predicted using models whose architecture consi

Rishikesh S 15 Aug 20, 2022
A demo of how to use JAX to create a simple gravity simulation

JAX Gravity This repo contains a demo of how to use JAX to create a simple gravity simulation. It uses JAX's experimental ode package to solve the dif

Cristian Garcia 16 Sep 22, 2022
Relative Positional Encoding for Transformers with Linear Complexity

Stochastic Positional Encoding (SPE) This is the source code repository for the ICML 2021 paper Relative Positional Encoding for Transformers with Lin

Antoine Liutkus 48 Nov 16, 2022
DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo.

dm_control: DeepMind Infrastructure for Physics-Based Simulation. DeepMind's software stack for physics-based simulation and Reinforcement Learning en

DeepMind 3k Dec 31, 2022
A-SDF: Learning Disentangled Signed Distance Functions for Articulated Shape Representation (ICCV 2021)

A-SDF: Learning Disentangled Signed Distance Functions for Articulated Shape Representation (ICCV 2021) This repository contains the official implemen

81 Dec 14, 2022
A simplistic and efficient pure-python neural network library from Phys Whiz with CPU and GPU support.

A simplistic and efficient pure-python neural network library from Phys Whiz with CPU and GPU support.

Manas Sharma 19 Feb 28, 2022
A package for "Procedural Content Generation via Reinforcement Learning" OpenAI Gym interface.

Readme: Illuminating Diverse Neural Cellular Automata for Level Generation This is the codebase used to generate the results presented in the paper av

Sam Earle 27 Jan 05, 2023
HHP-Net: A light Heteroscedastic neural network for Head Pose estimation with uncertainty

HHP-Net: A light Heteroscedastic neural network for Head Pose estimation with uncertainty Giorgio Cantarini, Francesca Odone, Nicoletta Noceti, Federi

18 Aug 02, 2022
Twins: Revisiting the Design of Spatial Attention in Vision Transformers

Twins: Revisiting the Design of Spatial Attention in Vision Transformers Very recently, a variety of vision transformer architectures for dense predic

482 Dec 18, 2022
Genetic feature selection module for scikit-learn

sklearn-genetic Genetic feature selection module for scikit-learn Genetic algorithms mimic the process of natural selection to search for optimal valu

Manuel Calzolari 260 Dec 14, 2022
Official Implementation of Neural Splines

Neural Splines: Fitting 3D Surfaces with Inifinitely-Wide Neural Networks This repository contains the official implementation of the CVPR 2021 (Oral)

Francis Williams 56 Nov 29, 2022
[CVPR2021] The source code for our paper 《Removing the Background by Adding the Background: Towards Background Robust Self-supervised Video Representation Learning》.

TBE The source code for our paper "Removing the Background by Adding the Background: Towards Background Robust Self-supervised Video Representation Le

Jinpeng Wang 150 Dec 28, 2022
Official repository for the CVPR 2021 paper "Learning Feature Aggregation for Deep 3D Morphable Models"

Deep3DMM Official repository for the CVPR 2021 paper Learning Feature Aggregation for Deep 3D Morphable Models. Requirements This code is tested on Py

38 Dec 27, 2022
[NeurIPS 2021] “Improving Contrastive Learning on Imbalanced Data via Open-World Sampling”,

Improving Contrastive Learning on Imbalanced Data via Open-World Sampling Introduction Contrastive learning approaches have achieved great success in

VITA 24 Dec 17, 2022
PyTorch implementation of our ICCV 2021 paper Intrinsic-Extrinsic Preserved GANs for Unsupervised 3D Pose Transfer.

Unsupervised_IEPGAN This is the PyTorch implementation of our ICCV 2021 paper Intrinsic-Extrinsic Preserved GANs for Unsupervised 3D Pose Transfer. Ha

25 Oct 26, 2022
A map update dataset and benchmark

MUNO21 MUNO21 is a dataset and benchmark for machine learning methods that automatically update and maintain digital street map datasets. Previous dat

16 Nov 30, 2022
Graph-based community clustering approach to extract protein domains from a predicted aligned error matrix

Using a predicted aligned error matrix corresponding to an AlphaFold2 model , returns a series of lists of residue indices, where each list corresponds to a set of residues clustering together into a

Tristan Croll 24 Nov 23, 2022
This repository contains the implementation of Deep Detail Enhancment for Any Garment proposed in Eurographics 2021

Deep-Detail-Enhancement-for-Any-Garment Introduction This repository contains the implementation of Deep Detail Enhancment for Any Garment proposed in

40 Dec 13, 2022