Overview

The PyTorch implementation of the IB-GAN model from AAAI 2021.

This package contains a PyTorch implementation of IB-GAN, presented in the paper "IB-GAN: Disentangled Representation Learning with Information Bottleneck Generative Adversarial Networks" (AAAI 2021).

You can reproduce the experiments on the dSprites, Color-dSprites, 3DChairs, and CelebA datasets with this code.

The current implementation is based on pytorch==1.4.0. Please refer to environments.yml for the environment settings.
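For example, assuming conda is installed, the environment can be created from that file with:

conda env create -f environments.yml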

Please refer to the Technical Appendix for more detailed information on the hyperparameter settings for each experiment.

Contents

  • Main code for dsprites (and cdsprite): "main.py"

  • IB-GAN model for dsprites (and cdsprite): "./model/model.py"

  • Disentanglement evaluation code for dsprites (and cdsprite): "evaluator.py", "checkout_scores.ipynb"

  • Main code for 3DChairs (and CelebA): "main2.py"

  • IB-GAN model for 3DChairs (and CelebA): "./model/model2.py"

Visdom for visualization

Since the default visdom option for main.py is True, you should first run the Visdom server before executing the main program by typing

python -m visdom.server -p 8097

Then you can observe the convergence plots and generated samples for each training iteration at

localhost:8097

Reproducing dSprite experiment

  • dSprite dataset : "./data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"

You can reproduce the dSprite experiment by typing:

python -W ignore main.py --seed 7 --z_dim 16 --r_dim 10 --batch_size 64 --optim rmsprop --dataset dsprites --viz True --viz_port 8097 --z_bias 0 --viz_name dsprites --beta 0.141 --alpha 1 --gamma 1 --G_lr 5e-5 --D_lr 1e-6 --max_iter 150000 --logiter 500 --ptriter 2500 --ckptiter 2500 --load_ckpt -1 --init_type normal --save_img True

Note that all the default parameter settings are optimally set up for the dSprite experiment (in the "main.py" file). For more details on the parameter settings for other datasets, please refer to the Technical Appendix.

  • dSprite dataset for Kim's disentanglement score evaluation: the evaluation file is currently not available (it will be updated soon). The evaluation process and code are the same as for the Color-dSprite experiment.

Reproducing Color-dSprite experiment

  • Color-dSprite dataset: the Color-dSprite dataset is currently not available.

However, you can create the Color-dSprites dataset yourself by recoloring the RGB channels of the original dSprites dataset.

Each RGB channel takes one of 8 discrete values: [0.00, 36.42, 72.85, 109.28, 145.71, 182.14, 218.57, 255.00].

Then move the resulting Color-dSprites npz file (e.g., cdsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz) to the folder ./data/dsprites-dataset/.
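The generation script is not included in this package; the following is a minimal sketch of one way to build the colored npz file. The npz key names, the random seed, the per-image random coloring, and the output file name are assumptions, not the authors' exact procedure; adjust them if the Technical Appendix specifies a different scheme.

import numpy as np

# Minimal sketch: colorize the binary dSprites images with the 8 discrete RGB levels above.
levels = np.array([0.00, 36.42, 72.85, 109.28, 145.71, 182.14, 218.57, 255.00])
src = np.load('./data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
              allow_pickle=True, encoding='latin1')
imgs = src['imgs']                      # (N, 64, 64) binary sprites
rng = np.random.RandomState(0)          # assumed seed, not from the paper

def colorize(img):
    rgb = levels[rng.randint(0, 8, size=3)]                            # one random level per RGB channel
    return (img[None, :, :] * rgb[:, None, None]).astype(np.uint8)     # (3, 64, 64)

colored = np.stack([colorize(im) for im in imgs])                      # note: needs several GB of RAM
np.savez_compressed('./data/dsprites-dataset/cdsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
                    imgs=colored,
                    latents_values=src['latents_values'],
                    latents_classes=src['latents_classes'])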

Run the code with the following arguments:

python -W ignore main.py --seed 7 --z_dim 16 --r_dim 10 --batch_size 64 --optim rmsprop --dataset cdsprites --viz True --viz_port 8097 --z_bias 0 --viz_name dsprites --beta 0.071 --alpha 1 --gamma 1 --G_lr 5e-5 --D_lr 1e-6 --max_iter 500000 --logiter 500 --ptriter 2500 --ckptiter 2500 --load_ckpt -1 --init_type normal --save_img True

  • Color-dSprite dataset for Kim's disentanglement score evaluation: "./data/imgs4eval_cdsprites.7z"

You first need to extract the "imgs4eval_cdsprites.7z" file using 7za. Please place all the extracted files in the "./data/imgs4eval_cdsprites/" folder.
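For example, assuming the 7za binary is installed and the archive is located at ./data/imgs4eval_cdsprites.7z, it can be extracted in place with:

7za x ./data/imgs4eval_cdsprites.7z -o./data/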

To run the evaluation on Kim's disentanglement metric, type

python evaluator.py --dset_dir data/imgs4eval_cdsprites --logiter 5000 --lastiter 500000 --name main

After all the evaluations for each checkpoint are done, you can see the overall disentanglement scores with the "checkout_scores.ipynb" Jupyter notebook, or you can simply type

import torch
# Load the aggregated disentanglement scores saved by evaluator.py.
torch.load('checkpoint/main/result.metric')

to see the scores in the Python console.

Reproducing CelebA experiment

  • CelebA dataset: please download the CelebA dataset and place 64x64 center-cropped image files into the folder ./data/CelebA/cropped_64 (see the sketch below).
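The exact cropping script is not part of this package. The following is a minimal sketch of one way to prepare such files; the source folder (./data/CelebA/img_align_celeba) and the 128-pixel center-crop size are assumptions, not the preprocessing used in the paper.

import os
from PIL import Image

src_dir = './data/CelebA/img_align_celeba'   # assumed location of the downloaded images
dst_dir = './data/CelebA/cropped_64'
os.makedirs(dst_dir, exist_ok=True)

crop = 128                                    # assumed center-crop size before downscaling
for name in os.listdir(src_dir):
    img = Image.open(os.path.join(src_dir, name)).convert('RGB')
    w, h = img.size
    left, top = (w - crop) // 2, (h - crop) // 2
    # Center-crop, then resize to 64x64 and save under the same file name.
    img = img.crop((left, top, left + crop, top + crop)).resize((64, 64), Image.BICUBIC)
    img.save(os.path.join(dst_dir, name))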

Then run the code with the following arguments:

python -W ignore main2.py --seed 0 --z_dim 64 --r_dim 15 --batch_size 64 --optim rmsprop --dataset celeba --viz_port 8097 --z_bias 0 --r_weight 0 --viz_name celeba --beta 0.35 --alpha 1 --gamma 1 --max_iter 1000000 --G_lr 5e-5 --D_lr 2e-6 --R_lr 5e-5 --ckpt_dir checkpoint --output_dir output --logiter 500 --ptriter 20000 --ckptiter 20000 --ngf 64 --ndf 64 --label_smoothing True --instance_noise_start 0.5 --instance_noise_end 0.01 --init_type orthogonal

Reproducing 3dChairs experiment

  • 3dChairs dataset: please download the 3DChairs dataset and move the image files into the folder ./data/3DChairs/images

Then run the code with the following arguments:

python -W ignore main2.py --seed 0 --z_dim 64 --r_dim 10 --batch_size 64 --optim rmsprop --dataset 3dchairs --viz_port 8097 --z_bias 0 --r_weight 0 --viz_name 3dchairs --beta 0.325 --alpha 1 --gamma 1 --max_iter 700000 --G_lr 5e-5 --D_lr 2e-6 --R_lr 5e-5 --ckpt_dir checkpoint --output_dir output --logiter 500 --ptriter 20000 --ckptiter 20000 --ngf 32 --ndf 32 --label_smoothing True --instance_noise_start 0.5 --instance_noise_end 0.01 --init_type orthogonal

Citing IB-GAN

If you like this work and end up using IB-GAN for your research, please cite our paper with the following BibTeX entry:

@inproceedings{jeon2021ib,
  title={IB-GAN: Disentangled Representation Learning with Information Bottleneck Generative Adversarial Networks},
  author={Jeon, Insu and Lee, Wonkwang and Pyeon, Myeongjang and Kim, Gunhee},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={9},
  pages={7926--7934},
  year={2021}
}

The disclosure and use of the currently published code are limited to research purposes only.

Owner
Insu Jeon
Stay hungry, stay foolish.