Adversarial-autoencoders - TensorFlow implementation of Adversarial Autoencoders

Overview

Adversarial Autoencoders (AAE)

  • TensorFlow implementation of Adversarial Autoencoders (ICLR 2016).
  • Like the variational autoencoder (VAE), an AAE imposes a prior on the latent variable z. However, instead of maximizing the evidence lower bound (ELBO) as a VAE does, an AAE uses an adversarial network to guide the model's distribution of z toward the prior distribution.
  • This repository reproduces several experiments from the paper.
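The adversarial regularization can be read as a GAN-style objective on codes: a discriminator learns to separate prior samples from encoder outputs, and the encoder is updated to fool it. The minimal pure-Python sketch below assumes sigmoid cross-entropy on discriminator logits, with prior samples labeled 1 and encoder codes labeled 0. The function names are hypothetical, and the exact losses in src/models/aae.py may differ.

```python
import math

def sigmoid_cross_entropy(logit: float, target: float) -> float:
    # Numerically stable binary cross-entropy on a raw logit:
    # max(x, 0) - x * t + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def discriminator_loss(real_logits, fake_logits):
    # Samples from the prior p(z) are "real" (target 1);
    # codes produced by the encoder are "fake" (target 0).
    real = sum(sigmoid_cross_entropy(l, 1.0) for l in real_logits) / len(real_logits)
    fake = sum(sigmoid_cross_entropy(l, 0.0) for l in fake_logits) / len(fake_logits)
    return real + fake

def generator_loss(fake_logits):
    # The encoder acts as the generator: it wants its codes classified as "real".
    return sum(sigmoid_cross_entropy(l, 1.0) for l in fake_logits) / len(fake_logits)
```

During training these terms would be scaled by loss weights such as --genw and --disw from the argument list, alongside the reconstruction loss weighted by --encw.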

Requirements

Implementation details

  • All the models of AAE are defined in src/models/aae.py.
  • The model corresponding to Figs. 1 and 3 in the paper can be found here: train and test.
  • The model corresponding to Fig. 6 in the paper can be found here: train and test.
  • The model corresponding to Fig. 8 in the paper can be found here: train and test.
  • Examples of how to use AAE models can be found in experiment/aae_mnist.py.
  • The encoder, decoder and all discriminators contain two fully connected layers with 1000 hidden units and ReLU activations. The decoder and all discriminators have an additional fully connected output layer.
  • Images are normalized to [-1, 1] before being fed into the encoder, and tanh is used as the output nonlinearity of the decoder.
  • All the sub-networks are optimized by Adam optimizer with beta1 = 0.5.
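A minimal sketch of the [-1, 1] normalization and its inverse, assuming raw MNIST pixels are 8-bit integers in [0, 255]. The helper names are hypothetical; the inverse maps tanh decoder outputs back to pixel values, e.g. before saving images.

```python
def to_model_range(pixels):
    # Map 8-bit pixel values in [0, 255] to [-1, 1], the tanh output range of the decoder.
    return [p / 127.5 - 1.0 for p in pixels]

def to_pixel_range(values):
    # Inverse mapping for decoder outputs in [-1, 1]; returns integer pixel values.
    return [round((v + 1.0) * 127.5) for v in values]
```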

Preparation

  • Download the MNIST dataset from here.
  • Set up paths in experiment/aae_mnist.py: DATA_PATH is the directory containing the MNIST dataset, and SAVE_PATH is the directory where output images and trained models are saved.

Usage

The script experiment/aae_mnist.py contains all the experiments shown here. Detailed usage for each experiment is described below, along with the results.

Arguments

  • --train: Train the model of Fig 1 and 3 in the paper.
  • --train_supervised: Train the model of Fig 6 in the paper.
  • --train_semisupervised: Train the model of Fig 8 in the paper.
  • --label: Incorporate label information in the adversarial regularization (Fig 3 in the paper).
  • --generate: Randomly sample images from trained model.
  • --viz: Visualize latent space and data manifold (only when --ncode is 2).
  • --supervise: Sample from the supervised model (Fig 6 in the paper); used together with --generate.
  • --load: The epoch ID of pre-trained model to be restored.
  • --ncode: Dimension of code. Default: 2
  • --dist_type: Type of prior distribution imposed on the hidden codes: gaussian (default) or gmm (Gaussian mixture).
  • --noise: Add noise to encoder input (Gaussian with std=0.6).
  • --lr: Initial learning rate. Default: 2e-4.
  • --dropout: Keep probability for dropout. Default: 1.0.
  • --bsize: Batch size. Default: 128.
  • --maxepoch: Max number of epochs. Default: 100.
  • --encw: Weight of autoencoder loss. Default: 1.0.
  • --genw: Weight of z generator loss. Default: 6.0.
  • --disw: Weight of z discriminator loss. Default: 6.0.
  • --clsw: Weight of semi-supervised loss. Default: 1.0.
  • --ygenw: Weight of y generator loss. Default: 6.0.
  • --ydisw: Weight of y discriminator loss. Default: 6.0.
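For reference, the documented flags and defaults can be collected into an argparse parser. This is a hypothetical reconstruction from the list above, not the actual parser in experiment/aae_mnist.py: the types, the restriction of --dist_type to two choices, and the None default for --load (no default is documented) are assumptions.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reconstruction of the documented CLI; the real script may differ.
    p = argparse.ArgumentParser(description="Adversarial autoencoders on MNIST")
    # Mode switches (boolean flags).
    for flag in ("--train", "--train_supervised", "--train_semisupervised",
                 "--label", "--generate", "--viz", "--supervise", "--noise"):
        p.add_argument(flag, action="store_true")
    # Restoring and prior configuration.
    p.add_argument("--load", type=int, default=None,
                   help="epoch ID of the pre-trained model to restore")  # assumed default
    p.add_argument("--ncode", type=int, default=2)
    p.add_argument("--dist_type", choices=("gaussian", "gmm"), default="gaussian")
    # Optimization settings.
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--dropout", type=float, default=1.0)  # keep probability
    p.add_argument("--bsize", type=int, default=128)
    p.add_argument("--maxepoch", type=int, default=100)
    # Loss weights.
    for name, default in (("--encw", 1.0), ("--genw", 6.0), ("--disw", 6.0),
                          ("--clsw", 1.0), ("--ygenw", 6.0), ("--ydisw", 6.0)):
        p.add_argument(name, type=float, default=default)
    return p
```

For example, build_parser().parse_args(["--train", "--ncode", "10", "--dist_type", "gmm"]) yields a namespace with train=True and all unspecified weights at their documented defaults.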

1. Adversarial Autoencoder

Architecture

Architecture Description
The top row is an autoencoder; z is sampled using the reparameterization trick described in the variational autoencoder paper. The bottom row is a discriminator that separates samples generated by the encoder from samples drawn from the prior distribution p(z).
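Assuming the encoder outputs a mean and a log-variance for each code dimension (an assumed parameterization; the README does not specify it), the reparameterization trick can be sketched as follows. The function name is hypothetical.

```python
import math
import random

def sample_code(mu, log_var, rng=random):
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1).
    # The randomness lives in eps, so gradients can flow through mu and sigma.
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]
```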

Hyperparameters

Hyperparameter | Value
Reconstruction loss weight | 1.0
Latent z G/D loss weight | 6.0 / 6.0
Batch size | 128
Max epoch | 400
Learning rate | 2e-4 (initial) / 2e-5 (after 100 epochs) / 2e-6 (after 300 epochs)
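The learning-rate schedule in the table is a step decay by a factor of 10. A minimal sketch, assuming each decay takes effect from the listed epoch onward; the function and parameter names are hypothetical.

```python
def learning_rate(epoch, base_lr=2e-4, milestones=(100, 300), decay=0.1):
    # Step schedule: multiply the learning rate by `decay` at each milestone epoch reached.
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= decay
    return lr
```

The same helper would cover the semi-supervised schedule of section 4 with learning_rate(epoch, base_lr=1e-4, milestones=(150, 200)).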

Usage

  • Training. Summaries, randomly sampled images and latent-space plots produced during training are saved in SAVE_PATH. TYPE_OF_PRIOR is either gaussian or gmm.
python aae_mnist.py --train \
  --ncode CODE_DIM \
  --dist_type TYPE_OF_PRIOR
  • Randomly sample data from the trained model. The image is saved in SAVE_PATH as generate_im.png. TYPE_OF_PRIOR is either gaussian or gmm.
python aae_mnist.py --generate \
  --ncode CODE_DIM \
  --dist_type TYPE_OF_PRIOR \
  --load RESTORE_MODEL_ID
  • Visualize the latent space and data manifold (only when the code dimension is 2). Images are saved in SAVE_PATH as generate_im.png and latent.png. With the Gaussian prior there is one data-manifold image; with the mixture of 10 2D Gaussians there are 10 data-manifold images, one per mixture component. TYPE_OF_PRIOR is either gaussian or gmm.
python aae_mnist.py --viz \
  --ncode CODE_DIM \
  --dist_type TYPE_OF_PRIOR \
  --load RESTORE_MODEL_ID

Result

  • For the 2D Gaussian prior, the learned code space shows sharp transitions (no gaps), as mentioned in the paper, and almost all images sampled from the learned manifold are readable.
  • For the mixture of 10 Gaussians, I uniformly sample images over a 2D square, as for the 2D Gaussian, instead of sampling along the axes of each mixture component (which is shown in the next section). Samples from the gap areas between two components are less likely to be good.
Figure: Prior Distribution | Learned Coding Space | Learned Manifold
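The uniform sampling over a 2D square used for the manifold images can be sketched as an evenly spaced grid of codes, each decoded into one image tile. The square bounds and grid size below are assumptions, not values taken from the script.

```python
def latent_grid(n=10, low=-2.0, high=2.0):
    # Evenly spaced n x n grid over [low, high]^2 in the 2D code space.
    # Rows run from high to low so the top image row matches the largest second coordinate.
    step = (high - low) / (n - 1)
    return [[(low + j * step, high - i * step) for j in range(n)] for i in range(n)]
```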

2. Incorporating Labels in the Adversarial Regularization

Architecture

Architecture Description
The only difference from the previous model is that the one-hot label is used as an input to the encoder, with one extra class for unlabeled data. For the mixture-of-Gaussians prior, real samples for each labeled class are drawn from that class's mixture component, while real samples for unlabeled data are drawn from the whole mixture.
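The label encoding and prior sampling described above can be sketched as follows. The mixture layout (10 isotropic 2D Gaussians with equal weights and means evenly spaced on a circle) and its radius and standard deviation are hypothetical choices, as are the helper names.

```python
import math
import random

N_CLASS = 10
UNLABELED = N_CLASS  # extra class index reserved for unlabeled data

def one_hot_label(label):
    # label is 0..9 for labeled images and None for unlabeled ones (11-dim one-hot).
    vec = [0.0] * (N_CLASS + 1)
    vec[UNLABELED if label is None else label] = 1.0
    return vec

def sample_gmm_prior(label, rng, radius=4.0, std=0.5):
    # Hypothetical layout: component k has its mean on a circle at angle 2*pi*k/10.
    # Labeled samples come from their class component; unlabeled samples pick a
    # component uniformly, which samples the equally weighted mixture.
    k = rng.randrange(N_CLASS) if label is None else label
    angle = 2.0 * math.pi * k / N_CLASS
    return (radius * math.cos(angle) + rng.gauss(0.0, std),
            radius * math.sin(angle) + rng.gauss(0.0, std))
```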

Hyperparameters

Hyperparameters are the same as previous section.

Usage

  • Training. Summaries, randomly sampled images and latent-space plots are saved in SAVE_PATH. TYPE_OF_PRIOR is either gaussian or gmm.
python aae_mnist.py --train --label \
  --ncode CODE_DIM \
  --dist_type TYPE_OF_PRIOR
  • Randomly sample data from the trained model. The image is saved in SAVE_PATH as generate_im.png. TYPE_OF_PRIOR is either gaussian or gmm.
python aae_mnist.py --generate --label \
  --ncode CODE_DIM \
  --dist_type TYPE_OF_PRIOR \
  --load RESTORE_MODEL_ID
  • Visualize the latent space and data manifold (only when the code dimension is 2). Images are saved in SAVE_PATH as generate_im.png and latent.png. With the Gaussian prior there is one data-manifold image; with the mixture of 10 2D Gaussians there are 10 data-manifold images, one per mixture component. TYPE_OF_PRIOR is either gaussian or gmm.
python aae_mnist.py --viz --label \
  --ncode CODE_DIM \
  --dist_type TYPE_OF_PRIOR \
  --load RESTORE_MODEL_ID

Result

  • Compared with the previous section, incorporating label information makes the codes fit the prior distribution better.
  • The learned manifold images show that each Gaussian component corresponds to one digit class. However, style is not represented consistently across the mixture components as it is in the paper. For example, in the rightmost column of the first-row experiment, the lower-right part of digit 1 tilts to the left while that of digit 9 tilts to the right.
Figure: Learned Coding Space | Learned Manifold, for two settings: (1) all labels used; (2) 10k labeled and 40k unlabeled images.

3. Supervised Adversarial Autoencoders

Architecture

Architecture Description
The decoder takes the code as well as a one-hot vector encoding the label as input. This forces the network to learn a code that is independent of the label.
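The decoder input and the generation grid used in the results (same code in each column, same digit label in each row) can be sketched as below. Concatenating the code before the label, and the helper names, are assumptions.

```python
def decoder_input(code, label, n_class=10):
    # The supervised AAE decoder sees the code concatenated with a one-hot label.
    one_hot = [0.0] * n_class
    one_hot[label] = 1.0
    return list(code) + one_hot

def style_grid(codes, n_class=10):
    # Row d pairs every code with digit label d, so each column shares one code (style).
    return [[decoder_input(c, d, n_class) for c in codes] for d in range(n_class)]
```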

Hyperparameters

Usage

  • Training. Summaries and randomly sampled images are saved in SAVE_PATH.
python aae_mnist.py --train_supervised \
  --ncode CODE_DIM
  • Randomly sample data from the trained model. The image is saved in SAVE_PATH as sample_style.png.
python aae_mnist.py --generate --supervise \
  --ncode CODE_DIM \
  --load RESTORE_MODEL_ID

Result

  • The result images are generated by using the same code for each column and the same digit label for each row.
  • With code dimension 2, each column clearly shares the same style. With code dimension 10, however, some digits are hard to read. This may be due to an implementation issue or poorly chosen hyperparameters, which leave the code still dependent on the label.
Figure: Code Dim = 2 | Code Dim = 10

4. Semi-supervised learning

Architecture

Architecture Description
The encoder outputs the code z as well as the estimated label y. The decoder takes the code z and the one-hot label y as input. A Gaussian distribution is imposed on the code z and a Categorical distribution on the label y. In this implementation, the semi-supervised classification phase is run once every ten training steps when 1000 labeled images are used, and the one-hot label y is approximated by the softmax output.
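The pieces described above (softmax approximation of y, categorical prior samples for the y discriminator, and a classification phase every ten steps) can be sketched as follows. The uniform categorical prior, the choice of step 0 as the first classification step, and all helper names are assumptions.

```python
import math
import random

def softmax(logits):
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_categorical_prior(n_class, rng):
    # "Real" samples for the y discriminator: one-hot vectors with a uniformly drawn class.
    y = [0.0] * n_class
    y[rng.randrange(n_class)] = 1.0
    return y

def run_classification_phase(step, every=10):
    # Run the semi-supervised classification update on one of every `every` steps.
    return step % every == 0
```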

Hyperparameters

Hyperparameter | Value
Dimension of z | 10
Reconstruction loss weight | 1.0
Latent z G/D loss weight | 6.0 / 6.0
Latent y G/D loss weight | 6.0 / 6.0
Batch size | 128
Max epoch | 250
Learning rate | 1e-4 (initial) / 1e-5 (after 150 epochs) / 1e-6 (after 200 epochs)

Usage

  • Training. Summaries are saved in SAVE_PATH.
python aae_mnist.py \
  --ncode 10 \
  --train_semisupervised \
  --lr 2e-4 \
  --maxepoch 250

Result

  • 1280 labels are used (128 labeled images per class)

Learning curve on the training set (computed only on the labeled training images).

Learning curve on the test set.

  • Test-set accuracy reaches 97.10% at around 200 epochs.
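Selecting a class-balanced labeled subset such as 128 images per class can be sketched as follows, treating all remaining training images as unlabeled. How the script actually chooses its labeled images is not stated; this helper is hypothetical.

```python
import random
from collections import defaultdict

def balanced_labeled_subset(labels, per_class, rng):
    # Group example indices by class, then draw `per_class` indices from each class.
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    chosen = []
    for y in sorted(by_class):
        chosen.extend(rng.sample(by_class[y], per_class))
    return sorted(chosen)
```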
Owner
Qian Ge
ECE PhD candidate at NCSU