Overview

Keras-GAN-Animeface-Character

GAN example for Keras. Because MNIST is too small, there should be an example on something more realistic.

Some results

Training for 22 epochs

YouTube video (click on the image)

Loss graph for 5000 mini-batches

1 mini-batch = 64 images, and the dataset has 14490 images, so 5000 mini-batches is approximately 22 epochs.
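The epoch arithmetic, spelled out with the batch size and dataset size stated above:

```python
batch_size = 64        # images per mini-batch
dataset_size = 14490   # images in the training set
mini_batches = 5000    # length of the loss graph

batches_per_epoch = dataset_size / batch_size   # about 226.4
epochs = mini_batches / batches_per_epoch       # about 22.1
print(f"{batches_per_epoch:.1f} mini-batches per epoch, {epochs:.1f} epochs")
```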

Some outputs of the 5000th mini-batch

Some training images

Useful resources, before you go on

How to run this example

Setup

  • My environment: Python 3.6 + Keras 2.0.4 + TensorFlow 1.x
    • If you are on Keras 2.0.0, you need to update it; otherwise BatchNormalization() will cause a bug, with the TensorFlow backend complaining something like "you need to pass float to input".
  • Use virtualenv to initialize a similar environment (python and dependencies):
pip install virtualenv
virtualenv -p <PATH_TO_BIN_DIR>/python3.6 venv
source venv/bin/activate
pip install -r requirements.txt
  • I HATE making programs that take lots of command-line parameters, so most of the parameters are set inside the scripts. Adjust the scripts as you need. The main() function is at the bottom of each script, as people do in C/C++.
  • Most global parameters are defined in args.py.
    • They are defined as class variables, not instance variables, so you may have trouble running/training multiple instances of the GAN with different parameters (which is very unlikely to happen).
  • Download dataset from http://www.nurs.or.jp/~nagadomi/animeface-character-dataset/
    • Extract it to this directory so that the script can find ./animeface-character-dataset/thumb/
    • Any dataset should work in principle, but GANs are sensitive to hyperparameters and may not work on yours. I tuned the parameters for animeface-character-dataset.
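A minimal sketch of why class variables make multiple GAN instances with different parameters awkward. The class names Args, GAN and RunArgs, the lr value, and the default for dataset_sz are hypothetical, for illustration only; the real args.py may be laid out differently. A dataclass with instance fields is shown as one possible alternative, not as what the scripts do.

```python
from dataclasses import dataclass

class Args:
    """Hypothetical args.py-style settings stored as CLASS variables."""
    batch_sz = 64        # images per mini-batch
    dataset_sz = 14490   # hypothetical default: the full dataset size
    lr = 0.0002          # hypothetical learning rate, illustration only

class GAN:
    """Hypothetical wrapper that reads and writes the shared Args."""
    def __init__(self, lr):
        Args.lr = lr     # changes the value seen by EVERY GAN, not just this one

    def current_lr(self):
        return Args.lr

a = GAN(lr=0.0002)
b = GAN(lr=0.001)
print(a.current_lr(), b.current_lr())  # both 0.001: b's setting overwrote a's

@dataclass
class RunArgs:
    """One possible alternative: per-run settings as instance fields."""
    batch_sz: int = 64
    dataset_sz: int = 14490
    lr: float = 0.0002

c = RunArgs(lr=0.001)
d = RunArgs()
print(c.lr, d.lr)  # each run keeps its own value
```

With class variables, the last writer wins for everyone; with instance fields, each run carries its own copy.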

Preprocessing

  • Run the preprocessing script. Resizing and scaling the inputs once, up front, saves training time compared with doing those tasks on the fly in the training loop.
    • ./data.py
    • When the images are loaded from the PNG files, the RGB values are uint8 in [0, 255]. data.py collects the images, resizes them to 64x64, and scales the RGB values into the [-1.0, 1.0] range.
    • data.py will only sample a subset of the dataset if configured to do so. The size of the subset is determined by dataset_sz, defined in args.py.
    • The images will be written to data.hdf5.
      • I made it small to verify that the training is working.
      • You can increase it, but you need to adjust the network sizes accordingly.
    • Again, which files to read is defined at the bottom of the script, not by sys.argv.
  • You need a large enough dataset. Otherwise the discriminator will sort of "memorize" the real data and reject everything that is generated.
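A minimal sketch of the scaling and subset-sampling steps described above, assuming the images are already NumPy uint8 arrays of shape (N, H, W, 3). The function names are hypothetical and not taken from data.py; the resize to 64x64 and the write to data.hdf5 are left out because they need an image library and h5py.

```python
import numpy as np

def scale_to_unit_range(images):
    """Map uint8 RGB values in [0, 255] to float32 values in [-1.0, 1.0]."""
    return images.astype(np.float32) / 127.5 - 1.0

def unscale_to_uint8(images):
    """Inverse mapping, e.g. for saving generator output as PNG."""
    return np.clip(np.rint((images + 1.0) * 127.5), 0, 255).astype(np.uint8)

def sample_subset(paths, dataset_sz, seed=0):
    """Pick dataset_sz distinct paths at random (or all of them if there are fewer)."""
    rng = np.random.default_rng(seed)
    n = min(dataset_sz, len(paths))
    idx = rng.choice(len(paths), size=n, replace=False)
    return [paths[i] for i in sorted(idx)]

# Usage with a fake batch of two 64x64 RGB images.
batch = np.random.default_rng(1).integers(0, 256, size=(2, 64, 64, 3)).astype(np.uint8)
x = scale_to_unit_range(batch)
print(x.dtype, x.min() >= -1.0, x.max() <= 1.0)
print(np.array_equal(unscale_to_uint8(x), batch))  # round trip recovers the pixels
```

Dividing by 127.5 and subtracting 1 maps 0 to -1.0 and 255 to 1.0, which matches a tanh output layer in the generator if one is used.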

Training

  • Open gan.py and, at the bottom, uncomment train_autoenc() if you wish.
    • This is useful for checking whether the generator network is capable of reproducing the input.
    • The auto-encoder will be trained on input images.
    • The output will be blurry, because the auto-encoder uses a mean-squared-error loss. (This is why GANs were invented in the first place!)
  • To run training, modify main() so that train_gan() is uncommented.
  • The script will dump reals.png and fakes.png every 10 epochs so that you can see how the training is going.
  • The training takes a while. For this example on the Anime Face dataset, it took about 10000 mini-batches to get good results.
    • If you still see only uniform color or "modern art" by mini-batch 2000, the training is not working!
  • The script also dumps weights every 10 batches. Use them to save training time: uncomment load_weights() in train_gan() to resume. Weights from before the training diverged are preferred :)
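A toy illustration (not the auto-encoder in gan.py) of why a mean-squared-error loss gives blurry output: if the same input can correspond to two equally likely sharp pixel values, the single prediction that minimizes MSE is their average. The pixel values -1.0 and +1.0 are an assumption, chosen on the [-1.0, 1.0] scale used by data.py.

```python
import numpy as np

# Two equally likely "sharp" values for the same pixel: fully dark or fully bright.
targets = np.array([-1.0, 1.0])

def mse(pred):
    # Mean squared error of one constant prediction against both targets.
    return float(np.mean((targets - pred) ** 2))

# Try every constant prediction on a grid and keep the one with the lowest MSE.
candidates = np.linspace(-1.0, 1.0, 201)
losses = np.array([mse(c) for c in candidates])
best = float(candidates[np.argmin(losses)])

print(best, mse(best))   # the midpoint 0.0 (a washed-out gray), with MSE 1.0
print(1.0, mse(1.0))     # committing to one sharp value costs MSE 2.0
```

The loss rewards hedging between the modes, so the "best" output is gray rather than either sharp value; a discriminator, by contrast, can reject such averages as fake.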

Training tips

What I experienced during my training of GAN.

  • As described in GAN Hacks, the discriminator should be ahead of the generator so that the generator can be "guided" by the discriminator.
  • If you look at the loss graph at https://github.com/osh/KerasGAN, their generator loss was in the range 2 to 4, and their training worked well. Their discriminator loss is low, around 0.1.
  • You'll need trial and error to get the hyperparameters right so that the training stays in the stable, balanced zone. That includes the learning rates of D and G, momentum values, etc.
  • Convergence is quite sensitive to the learning rate (LR), beware!
  • If things go well, the discriminator losses for detecting real and fake inputs (dloss0 and dloss1) should be around 0.1 or less, which means D is good at telling whether the input is real or fake.
  • If the LR is too high, the discriminator will diverge: one of the losses will get high and will not fall. Training fails in this case.
  • Making the LR too small only slows learning; it will not prevent other issues such as oscillation. The LR only needs to be below a certain threshold, which is data dependent.
  • If adjusting the LR doesn't work, the discriminator may lack complexity. Add more layers, or tweak some other parameters. It could be anything :( Good luck!
  • On the other hand, the generator loss will be relatively higher than the discriminator loss. In this script, it oscillates in the range 0.1 to 4.
  • If you see either D loss staying above 15 (with a batch size of 32), the training is screwed.
  • If G loss goes above 15, see whether it comes back down within 30 batches. If it stays there for too long, it isn't good, I think.
  • If you're seeing high G loss, it could mean the generator can't keep up with the discriminator. You might need to increase G's LR (it must stay lower than the discriminator's, though).
  • One final piece of the training I was missing was the parameter in BatchNormalization. I found out about it in this link: https://github.com/shekkizh/neuralnetworks.thought-experiments/blob/master/Generative%20Models/GAN/Readme.md
    • Sort of interesting: in PyTorch, the momentum parameter for BatchNorm is 0.1 according to the API documents, while in Keras it is 0.99. I'm not sure if 0.1 in PyTorch actually means 1 - 0.1. I didn't look into the PyTorch backend implementation.
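A small NumPy simulation of the two running-average conventions, following the update rules given in the Keras BatchNormalization and PyTorch BatchNorm documentation. Keras updates moving = moving * momentum + batch * (1 - momentum), while PyTorch updates running = (1 - momentum) * running + momentum * batch, so a PyTorch momentum of 0.1 behaves like a Keras momentum of 1 - 0.1 = 0.9, not like the Keras default of 0.99. The batch means below are random numbers, purely for illustration; no framework code is called.

```python
import numpy as np

def keras_update(moving, batch_stat, momentum):
    # Keras BatchNormalization rule: momentum weights the OLD moving statistic.
    return moving * momentum + batch_stat * (1.0 - momentum)

def torch_update(running, batch_stat, momentum):
    # PyTorch BatchNorm rule: momentum weights the NEW batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat

rng = np.random.default_rng(0)
batch_means = rng.normal(loc=0.5, scale=1.0, size=200)  # fake per-batch means

keras_09 = keras_099 = torch_01 = 0.0
for m in batch_means:
    keras_09 = keras_update(keras_09, m, momentum=0.9)     # 1 - 0.1
    keras_099 = keras_update(keras_099, m, momentum=0.99)  # Keras default
    torch_01 = torch_update(torch_01, m, momentum=0.1)     # PyTorch default

print(keras_09, torch_01)  # identical up to float rounding
print(keras_099)           # a slower-moving average, so a different value
```

In other words, the Keras default of 0.99 averages the batch statistics over a much longer window than the PyTorch default does.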