Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly

Overview

Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly

License: MIT

Code for this paper Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly. [Preprint]

Tianlong Chen, Yu Cheng, Zhe Gan, Jingjing Liu, Zhangyang Wang.

Overview

Training generative adversarial networks (GANs) with limited data generally results in deteriorated performance and collapsed models. To conquerthis challenge, we are inspired by the latest observation of Kalibhat et al. (2020); Chen et al.(2021d), that one can discover independently trainable and highly sparse subnetworks (a.k.a.,lottery tickets) from GANs. Treating this as aninductive prior, we decompose the data-hungry GAN training into two sequential sub-problems:

  • (i) identifying the lottery ticket from the original GAN;
  • (ii) then training the found sparse subnetwork with aggressive data and feature augmentations.

Both sub-problems re-use the same small training set of real images. Such a coordinated framework enables us to focus on lower-complexity and more data-efficient sub-problems, effectively stabilizing trainingand improving convergence.

Methodology

Experiment Results

More experiments can be found in our paper.

Implementation

For the first step, finding the lottery tickets in GAN is referred to this repo.

For the second step, training GAN ticket toughly are provides as follow:

Environment for SNGAN

conda install python3.6
conda install pytorch1.4.0 -c pytorch
pip install tensorflow-gpu==1.13
pip install imageio
pip install tensorboardx

R.K. Donwload fid statistics from Fid_Stat.

Commands for SNGAN

R.K. Limited data training for SNGAN

  • Dataset: CIFAR-10

Example for full model training on 20% limited data (--ratio 0.2):

python train_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --ratio 0.2

Example for full model training on 20% limited data (--ratio 0.2) with AdvAug on G and D:

python train_adv_gd_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --gamma 0.01 --step 1 --ratio 0.2

Example for sparse model (i.e., GAN tickets) training on 20% limited data (--ratio 0.2) with AdvAug on G and D:

python train_with_masks_adv_gd_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --gamma 0.01 --step 1 --ratio 0.2 --rewind-path <>
  • --rewind-path: the stored path of identified sparse masks

Environment for BigGAN

conda env create -f environment.yml studiogan

Commands for BigGAN

R.K. Limited data training for BigGAN

  • Dataset: TINY ILSVRC

Example:

python main_ompg.py -t -e -c ./configs/TINY_ILSVRC2012/BigGAN_adv.json --eval_type valid --seed 42 --mask_path checkpoints/BigGAN-train-0.1 --mask_round 2 --reduce_train_dataset 0.1 --gamma 0.01 
  • --mask_path: the stored path of identified sparse masks
  • --mask_round: the sparsity level = 0.8 ^ mask_round
  • --reduce_train_dataset: the size of used limited training data
  • --gamma: hyperparameter for AdvAug. You can set it to 0 to git rid of AdvAug

  • Dataset: CIFAR100

Example:

python main_ompg.py -t -e -c ./configs/CIFAR100_less/DiffAugGAN_adv.json --ratio 0.2 --mask_path checkpoints/diffauggan_cifar100_0.2 --mask_round 9 --seed 42 --gamma 0.01
  • DiffAugGAN_adv.json: it indicate this confirguration use DiffAug.

Pre-trained Models

  • SNGAN / CIFAR-10 / 10% Training Data / 10.74% Remaining Weights

https://www.dropbox.com/sh/7v8hn2859cvm7jj/AACyN8FOkMjgMwy5ibVj61IPa?dl=0

  • SNGAN / CIFAR-10 / 10% Training Data / 10.74% Remaining Weights + AdvAug on G and D

https://www.dropbox.com/sh/gsklrdcjzogqzcd/AAALlIYcWOZuERLcocKIqlEya?dl=0

  • BigGAN / CIFAR-10 / 10% Training Data / 13.42% Remaining Weights + DiffAug + AdvAug on G and D

https://www.dropbox.com/sh/epuajb1iqn5xma6/AAAD0zwehky1wvV3M3-uesHsa?dl=0

  • BigGAN / CIFAR-100 10% / Training Data / 13.42% Remaining Weights + DiffAug + AdvAug on G and D

https://www.dropbox.com/sh/y3pqdqee39jpct4/AAAsSebqHwkWmjO_O8Hp0hcEa?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / Full model

https://www.dropbox.com/sh/2rmvqwgcjir1p2l/AABNEo0B-0V9ZSnLnKF_OUA3a?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / Full model + AdvAug on G and D

https://www.dropbox.com/sh/pbwjphualzdy2oe/AACZ7VYJctNBKz3E9b8fgj_Ia?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / 64% Remaining Weights

https://www.dropbox.com/sh/82i9z44uuczj3u3/AAARsfNzOgd1R9sKuh1OqUdoa?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / 64% Remaining Weights + AdvAug on G and D

https://www.dropbox.com/sh/yknk1joigx0ufbo/AAChMvzCsedejFjY1XxGcaUta?dl=0

Citation

@misc{chen2021ultradataefficient,
      title={Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly}, 
      author={Tianlong Chen and Yu Cheng and Zhe Gan and Jingjing Liu and Zhangyang Wang},
      year={2021},
      eprint={2103.00397},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Acknowledgement

https://github.com/VITA-Group/GAN-LTH

https://github.com/GongXinyuu/sngan.pytorch

https://github.com/VITA-Group/AutoGAN

https://github.com/POSTECH-CVLab/PyTorch-StudioGAN

https://github.com/mit-han-lab/data-efficient-gans

https://github.com/lucidrains/stylegan2-pytorch

Owner
VITA
Visual Informatics Group @ University of Texas at Austin
VITA
A curated list of programmatic weak supervision papers and resources

A curated list of programmatic weak supervision papers and resources

Jieyu Zhang 118 Jan 02, 2023
My coursework for Machine Learning (2021 Spring) at National Taiwan University (NTU)

Machine Learning 2021 Machine Learning (NTU EE 5184, Spring 2021) Instructor: Hung-yi Lee Course Website : (https://speech.ee.ntu.edu.tw/~hylee/ml/202

100 Dec 26, 2022
LieTransformer: Equivariant Self-Attention for Lie Groups

LieTransformer This repository contains the implementation of the LieTransformer used for experiments in the paper LieTransformer: Equivariant Self-At

OxCSML (Oxford Computational Statistics and Machine Learning) 50 Dec 28, 2022
Happywhale - Whale and Dolphin Identification Silver🥈 Solution (26/1588)

Kaggle-Happywhale Happywhale - Whale and Dolphin Identification Silver 🥈 Solution (26/1588) 竞赛方案思路 图像数据预处理-标志性特征图片裁剪:首先根据开源的标注数据训练YOLOv5x6目标检测模型,将训练集

Franxx 20 Nov 14, 2022
Official Implementation of Swapping Autoencoder for Deep Image Manipulation (NeurIPS 2020)

Swapping Autoencoder for Deep Image Manipulation Taesung Park, Jun-Yan Zhu, Oliver Wang, Jingwan Lu, Eli Shechtman, Alexei A. Efros, Richard Zhang UC

449 Dec 27, 2022
Implementation of "A Deep Learning Loss Function based on Auditory Power Compression for Speech Enhancement" by pytorch

This repository is used to suspend the results of our paper "A Deep Learning Loss Function based on Auditory Power Compression for Speech Enhancement"

ScorpioMiku 19 Sep 30, 2022
Simple ONNX operation generator. Simple Operation Generator for ONNX.

sog4onnx Simple ONNX operation generator. Simple Operation Generator for ONNX. https://github.com/PINTO0309/simple-onnx-processing-tools Key concept V

Katsuya Hyodo 6 May 15, 2022
Attack classification models with transferability, black-box attack; unrestricted adversarial attacks on imagenet

Attack classification models with transferability, black-box attack; unrestricted adversarial attacks on imagenet, CVPR2021 安全AI挑战者计划第六期:ImageNet无限制对抗攻击 决赛第四名(team name: Advers)

51 Dec 01, 2022
A variational Bayesian method for similarity learning in non-rigid image registration (CVPR 2022)

A variational Bayesian method for similarity learning in non-rigid image registration We provide the source code and the trained models used in the re

daniel grzech 14 Nov 21, 2022
【Arxiv】Exploring Separable Attention for Multi-Contrast MR Image Super-Resolution

SANet Exploring Separable Attention for Multi-Contrast MR Image Super-Resolution Dependencies numpy==1.18.5 scikit_image==0.16.2 torchvision==0.8.1 to

36 Jan 05, 2023
JittorVis - Visual understanding of deep learning models

JittorVis: Visual understanding of deep learning model JittorVis is an open-source library for understanding the inner workings of Jittor models by vi

thu-vis 182 Jan 06, 2023
A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules

CapsNet-Tensorflow A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules Notes: The current version

Huadong Liao 3.8k Dec 29, 2022
Official PyTorch code for Mutual Affine Network for Spatially Variant Kernel Estimation in Blind Image Super-Resolution (MANet, ICCV2021)

Mutual Affine Network for Spatially Variant Kernel Estimation in Blind Image Super-Resolution (MANet, ICCV2021) This repository is the official PyTorc

Jingyun Liang 139 Dec 29, 2022
[ACM MM2021] MGH: Metadata Guided Hypergraph Modeling for Unsupervised Person Re-identification

Introduction This project is developed based on FastReID, which is an ongoing ReID project. Projects BUC In projects/BUC, we implement AAAI 2019 paper

WuYiming 7 Apr 13, 2022
A list of all named GANs!

The GAN Zoo Every week, new GAN papers are coming out and it's hard to keep track of them all, not to mention the incredibly creative ways in which re

Avinash Hindupur 12.9k Jan 08, 2023
This repository provides a PyTorch implementation and model weights for HCSC (Hierarchical Contrastive Selective Coding)

HCSC: Hierarchical Contrastive Selective Coding This repository provides a PyTorch implementation and model weights for HCSC (Hierarchical Contrastive

YUANFAN GUO 111 Dec 20, 2022
Automatic differentiation with weighted finite-state transducers.

GTN: Automatic Differentiation with WFSTs Quickstart | Installation | Documentation What is GTN? GTN is a framework for automatic differentiation with

100 Dec 29, 2022
The official PyTorch code implementation of "Personalized Trajectory Prediction via Distribution Discrimination" in ICCV 2021.

Personalized Trajectory Prediction via Distribution Discrimination (DisDis) The official PyTorch code implementation of "Personalized Trajectory Predi

25 Dec 20, 2022
NDE: Climate Modeling with Neural Diffusion Equation, ICDM'21

Climate Modeling with Neural Diffusion Equation Introduction This is the repository of our accepted ICDM 2021 paper "Climate Modeling with Neural Diff

Jeehyun Hwang 5 Dec 18, 2022
PyTorch implementation for 3D human pose estimation

Towards 3D Human Pose Estimation in the Wild: a Weakly-supervised Approach This repository is the PyTorch implementation for the network presented in:

Xingyi Zhou 579 Dec 22, 2022