GAN example for Keras, because MNIST is too small and there should be an example on something more realistic.

Overview

Keras-GAN-Animeface-Character

Some results

Training for 22 epochs

Loss graph for 5000 mini-batches

1 mini-batch = 64 images and the dataset has 14490 images, hence 5000 mini-batches is approximately 22 epochs.
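
The epoch count above can be checked with a few lines of arithmetic. This is only a sketch; the variable names are mine and the numbers are the ones quoted in this section.

```python
batch_size = 64        # images per mini-batch (from the text)
dataset_size = 14490   # images in the dataset (from the text)
mini_batches = 5000    # length of the loss graph

batches_per_epoch = dataset_size / batch_size   # mini-batches needed to see every image once
epochs = mini_batches / batches_per_epoch       # how many passes 5000 mini-batches amount to

print(f"{batches_per_epoch:.1f} mini-batches per epoch, {epochs:.1f} epochs")
# prints "226.4 mini-batches per epoch, 22.1 epochs"
```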

Some outputs of the 5000th mini-batch

Some training images

Useful resources, before you go on

How to run this example

Setup

  • My environment: Python 3.6 + Keras 2.0.4 + TensorFlow 1.x
    • If you are on Keras 2.0.0, update it. Otherwise BatchNormalization() fails with an error from the TensorFlow backend, something like "you need to pass float to input".
  • Use virtualenv to initialize a similar environment (Python and dependencies):
        pip install virtualenv
        virtualenv -p <PATH_TO_BIN_DIR>/python3.6 venv
        source venv/bin/activate
        pip install -r requirements.txt
  • I HATE making a program that needs many command line parameters, so most of the parameters live in the scripts themselves. Adjust the scripts as you need. The main() function is at the bottom of each script, as people do in C/C++.
  • Most global parameters are defined in args.py.
    • They are defined as class variables, not instance variables, so you may have trouble running or training multiple instances of the GAN with different parameters in one process (which is very unlikely to be needed).
  • Download dataset from http://www.nurs.or.jp/~nagadomi/animeface-character-dataset/
    • Extract it to this directory so that the script can find ./animeface-character-dataset/thumb/
    • Any dataset should work in principle, but GANs are sensitive to hyperparameters and may not work on yours. I tuned the parameters for animeface-character-dataset.
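
The class-variable caveat above can be illustrated with a small sketch. The `Args` class and the `batch_sz` value here are hypothetical stand-ins, not the real contents of args.py; only `dataset_sz` is a name mentioned in this README.

```python
class Args:
    # Hypothetical stand-in for args.py: parameters stored as class variables.
    batch_sz = 64
    dataset_sz = 14490

first = Args()
second = Args()

# Changing the class attribute is seen by every instance that has not shadowed it.
Args.batch_sz = 32
print(first.batch_sz, second.batch_sz)                 # 32 32

# Assigning through an instance creates an attribute on that one object only.
second.batch_sz = 128
print(first.batch_sz, second.batch_sz, Args.batch_sz)  # 32 128 32
```

Code that reads the parameters through the class (for example `Args.batch_sz`) always sees the shared value, which is why two differently configured GANs in the same process would step on each other.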

Preprocessing

  • Run the preprocessing script. Resizing and scaling the inputs once up front saves training time compared with doing those tasks on the fly in the training loop.
    • ./data.py
    • When loaded from PNG files, the images have RGB values in [0, 255] (uint8 type). data.py collects the images, resizes them to 64x64 and scales the RGB values so that they are in the [-1.0, 1.0] range.
    • data.py samples only a subset of the dataset if configured to do so. The size of the subset is determined by dataset_sz, defined in args.py.
    • The images will be written to data.hdf5.
      • I made it small to verify that the training is working.
      • You can increase it, but you need to adjust the network sizes accordingly.
    • Again, which files to read is defined at the bottom of the script, not by sys.argv.
  • You need a large enough dataset. Otherwise the discriminator will sort of "memorize" the real data and reject everything that is generated.
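
The subset sampling and value scaling described above could look roughly like the sketch below. It is not the code from data.py: the function names are made up for illustration, random arrays stand in for the PNG files, and resizing and the HDF5 write are left out.

```python
import random
import numpy as np

def sample_subset(paths, dataset_sz, seed=0):
    """Return at most dataset_sz paths, chosen at random without replacement."""
    if dataset_sz >= len(paths):
        return list(paths)
    return random.Random(seed).sample(list(paths), dataset_sz)

def scale_to_unit_range(images_uint8):
    """Map uint8 RGB values in [0, 255] to float32 values in [-1.0, 1.0]."""
    return images_uint8.astype(np.float32) / 127.5 - 1.0

def unscale_to_uint8(images_float):
    """Inverse mapping, handy when writing generated samples back to PNG."""
    return np.clip((images_float + 1.0) * 127.5, 0, 255).round().astype(np.uint8)

# Stand-in batch: 4 random 64x64 RGB "images" instead of real PNG files.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 64, 64, 3), dtype=np.uint8)

scaled = scale_to_unit_range(fake_images)
print(scaled.dtype, scaled.min() >= -1.0, scaled.max() <= 1.0)  # float32 True True
```

The [-1.0, 1.0] range matches a generator whose last activation is tanh, which is a common choice; whether gan.py does this is an assumption on my part.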

Training

  • Open gan.py and, at the bottom, uncomment train_autoenc() if you wish.
    • This is useful for seeing the generator network's capability to reproduce the input.
    • The auto-encoder will be trained on input images.
    • The output will be blurry because the auto-encoder uses a mean-squared-error loss. (This is why GANs got invented in the first place!)
  • To run training, modify main() so that train_gan() is uncommented.
  • The script will dump reals.png and fakes.png every 10 epochs so that you can see how the training is going.
  • The training takes a while. For this example on the Anime Face dataset, it took about 10000 mini-batches to get good results.
    • If you still see only uniform color or "modern art" after 2000 mini-batches, the training is not working!
  • The script also dumps weights every 10 batches. Use them to save training time; weights from before the training diverged are preferred :) To load them, uncomment load_weights() in train_gan().

Training tips

What I experienced while training the GAN.

  • As described in GAN Hacks, the discriminator should be ahead of the generator so that the generator can be "guided" by the discriminator.
  • If you look at the loss graph at https://github.com/osh/KerasGAN, they had gen loss in the range of 2 to 4 and their training worked well. The discriminator loss is low, around 0.1.
  • You'll need trial and error to get the hyperparameters right so that the training stays in the stable, balanced zone. That includes the learning rates of D and G, momentum, etc.
  • Convergence is quite sensitive to LR, beware!
  • If things go well, the discriminator loss for detecting real/fake (dloss0/dloss1) should be around 0.1 or less, which means it is good at telling whether the input is real or fake.
  • If the learning rate is too high, the discriminator will diverge: one of the losses will get high and will not fall. Training fails in this case.
  • If you make LR too small, it will only slow the learning and will not prevent other issues such as oscillation. LR only needs to be lower than a certain threshold that is data dependent.
  • If adjusting LR doesn't work, the discriminator may lack complexity. Add more layers, or change some other parameters. It could be anything :( Good luck!
  • On the other hand, the generator loss will be relatively higher than the discriminator loss. In this script, it oscillates in the range 0.1 to 4.
  • If you see either of the D losses staying > 15 (when batch size is 32), the training is screwed.
  • In case of G loss > 15, see if it escapes within 30 batches. If it stays there for too long, it isn't good, I think.
  • If you're seeing high G loss, it could mean the generator can't keep up with the discriminator. You might need to increase the generator's LR. (It must still be lower than the discriminator's, though.)
  • One final piece of the training I was missing was the parameter in BatchNormalization. I found out about it at this link: https://github.com/shekkizh/neuralnetworks.thought-experiments/blob/master/Generative%20Models/GAN/Readme.md
    • Sort of interesting: in PyTorch the momentum parameter for BatchNorm is 0.1 according to the API documents, while in Keras it is 0.99. I'm not sure if 0.1 in PyTorch actually means 1 - 0.1. I didn't look into the PyTorch backend implementation.
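
On the momentum question above: as far as the two libraries' documentation describes it, Keras and PyTorch use opposite conventions for the same running-average update, so PyTorch's 0.1 would correspond to 0.9 in Keras terms (not 0.99). The sketch below only restates those documented update rules in plain Python; the function names are mine and it does not call either library.

```python
def keras_style_update(moving, batch_stat, momentum=0.99):
    # Keras convention: momentum is the weight kept on the old moving value.
    return moving * momentum + batch_stat * (1.0 - momentum)

def pytorch_style_update(running, batch_stat, momentum=0.1):
    # PyTorch convention: momentum is the weight given to the new batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat

moving, batch_mean = 0.0, 1.0

# Same update step, written in each convention.
keras_result = keras_style_update(moving, batch_mean, momentum=0.9)
pytorch_result = pytorch_style_update(moving, batch_mean, momentum=0.1)
print(abs(keras_result - pytorch_result) < 1e-12)  # True
```

Under that reading, the two defaults really are different: Keras' 0.99 averages over a much longer history than PyTorch's 0.1, so the value is worth setting explicitly when porting a GAN between the two.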