Implementation of Analyzing and Improving the Image Quality of StyleGAN (StyleGAN 2) in PyTorch

StyleGAN 2 in PyTorch

Implementation of Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) in PyTorch

Notice

I have tried to match the official implementation as closely as possible, but I may have missed some details, so please use this implementation with care.

Requirements

I have tested on:

  • PyTorch 1.3.1
  • CUDA 10.1/10.2

Usage

First, create LMDB datasets:

python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH

This will convert the images to JPEG and pre-resize them. This implementation does not use progressive growing, but you can create datasets at multiple resolutions by passing a comma-separated list to --size, in case you want to try other resolutions later.
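The multi-resolution idea can be illustrated in plain Python: parse a comma-separated size list and plan one record key per image per resolution. This is a hypothetical sketch; the key format "{size}-{index:05d}" and the helper names are assumptions chosen for illustration, so check prepare_data.py for the layout the script actually writes.

```python
from __future__ import annotations

# Hypothetical sketch: parse a comma-separated --size value and plan one
# record key per image per resolution. The "{size}-{index:05d}" key format
# is an illustrative assumption, not necessarily what prepare_data.py writes.


def parse_sizes(arg: str) -> list[int]:
    """Turn '128,256,512' into [128, 256, 512]; int() tolerates spaces."""
    return [int(s) for s in arg.split(",") if s.strip()]


def record_key(size: int, index: int) -> bytes:
    """Hypothetical key for image `index` stored at resolution `size`."""
    return f"{size}-{index:05d}".encode("utf-8")


def planned_keys(size_arg: str, n_images: int) -> dict[int, list[bytes]]:
    """Every key a multi-resolution dataset would hold, grouped by size."""
    return {
        size: [record_key(size, i) for i in range(n_images)]
        for size in parse_sizes(size_arg)
    }


if __name__ == "__main__":
    for size, keys in planned_keys("128,256", n_images=2).items():
        print(size, keys)
```

In a layout like this, each resolution gets its own key prefix, so a single database can serve training at any of the listed sizes without re-running the conversion.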

Then you can train the model in a distributed setting:

python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH
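A quick way to relate iterations to training images is the arithmetic below. It assumes --batch is the per-GPU batch size, so each iteration consumes BATCH_SIZE times N_GPU real images; that reading is an assumption, though it is consistent with the sample captions in this README, which both imply 32 images per iteration.

```python
# Sketch: relate training iterations to the number of real images seen.
# Assumption: --batch is the per-process (per-GPU) batch size, so one
# iteration consumes batch_per_gpu * n_gpu images across all processes.


def images_seen(iterations: int, batch_per_gpu: int, n_gpu: int) -> int:
    """Total real images drawn after `iterations` training steps."""
    return iterations * batch_per_gpu * n_gpu


def effective_batch(total_images: int, iterations: int) -> float:
    """Back out the images-per-iteration implied by a reported total."""
    return total_images / iterations


if __name__ == "__main__":
    # The sample captions report 3.52M images at 110,000 iterations and
    # 4.8M images at 150,000 iterations.
    print(effective_batch(3_520_000, 110_000))  # 32.0
    print(effective_batch(4_800_000, 150_000))  # 32.0
    # One hypothetical way to reach 32 images per step: 4 GPUs with --batch 8.
    print(images_seen(110_000, batch_per_gpu=8, n_gpu=4))  # 3520000
```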

train.py supports Weights & Biases logging. If you want to use it, add the --wandb argument to the command.

SWAGAN

This implementation experimentally supports SWAGAN: A Style-based Wavelet-driven Generative Model (https://arxiv.org/abs/2102.06108). You can train SWAGAN with:

python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --arch swagan --batch BATCH_SIZE LMDB_PATH

As noted in the paper, SWAGAN trains much faster (about 2x faster at 256px).
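SWAGAN works with wavelet sub-bands rather than raw pixels. As a rough standalone illustration, assuming a Haar basis, the sketch below performs a one-level 2D Haar decomposition and its exact inverse on a nested list. It is not the code this repository uses; it only shows what the decomposition does.

```python
from typing import List, Tuple

# Illustrative only: one level of a 2D Haar wavelet transform on a nested
# list. Scaling by 1/2 makes the transform orthonormal, so it preserves
# energy (sum of squares) and inverts exactly.

Grid = List[List[float]]


def zeros(h: int, w: int) -> Grid:
    return [[0.0] * w for _ in range(h)]


def haar_2d(x: Grid) -> Tuple[Grid, Grid, Grid, Grid]:
    """Split an image with even height and width into LL, LH, HL, HH bands."""
    h, w = len(x) // 2, len(x[0]) // 2
    ll, lh, hl, hh = zeros(h, w), zeros(h, w), zeros(h, w), zeros(h, w)
    for i in range(h):
        for j in range(w):
            a, b = x[2 * i][2 * j], x[2 * i][2 * j + 1]
            c, d = x[2 * i + 1][2 * j], x[2 * i + 1][2 * j + 1]
            ll[i][j] = (a + b + c + d) / 2  # local average (low-pass)
            lh[i][j] = (a - b + c - d) / 2  # left column minus right column
            hl[i][j] = (a + b - c - d) / 2  # top row minus bottom row
            hh[i][j] = (a - b - c + d) / 2  # diagonal difference
    return ll, lh, hl, hh


def inverse_haar_2d(ll: Grid, lh: Grid, hl: Grid, hh: Grid) -> Grid:
    """Reassemble the image from its four sub-bands."""
    h, w = len(ll), len(ll[0])
    x = zeros(2 * h, 2 * w)
    for i in range(h):
        for j in range(w):
            s, p, q, r = ll[i][j], lh[i][j], hl[i][j], hh[i][j]
            x[2 * i][2 * j] = (s + p + q + r) / 2
            x[2 * i][2 * j + 1] = (s - p + q - r) / 2
            x[2 * i + 1][2 * j] = (s + p - q - r) / 2
            x[2 * i + 1][2 * j + 1] = (s - p - q + r) / 2
    return x
```

On smooth images most of the energy lands in the LL band, which is the kind of frequency separation a wavelet-driven generator can exploit.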

Convert weight from official checkpoints

You need to clone the official repository (https://github.com/NVlabs/stylegan2), as it is required to load the official checkpoints.

For example, if you cloned the repository into ~/stylegan2 and downloaded stylegan2-ffhq-config-f.pkl, you can convert it like this:

python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl

This will create the converted stylegan2-ffhq-config-f.pt file.

Generate samples

python generate.py --sample N_FACES --pics N_PICS --ckpt PATH_CHECKPOINT

If you trained at a different resolution, change --size accordingly (for example, --size 256).
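Several samples in this README use truncation. A minimal sketch of the truncation trick in W space follows; it assumes w_avg is the mean of many mapped latents and psi is the truncation strength, and both names are hypothetical here rather than taken from generate.py.

```python
from typing import List, Sequence

# Sketch of the truncation trick in W space. Assumptions: w_avg is the mean
# of many mapped latents, and psi in [0, 1] sets how strongly samples are
# pulled toward it (1.0 = no truncation, 0.0 = always the average).


def mean_latent(latents: Sequence[Sequence[float]]) -> List[float]:
    """Element-wise mean of a batch of latent vectors."""
    n = len(latents)
    return [sum(col) / n for col in zip(*latents)]


def truncate(w: Sequence[float], w_avg: Sequence[float], psi: float) -> List[float]:
    """Interpolate from w_avg toward w by a factor psi."""
    return [a + psi * (x - a) for x, a in zip(w, w_avg)]
```

Lower psi usually trades sample diversity for image quality, since latents far from the average tend to produce more artifacts.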

Project images to latent space

python projector.py --ckpt [CHECKPOINT] --size [GENERATOR_OUTPUT_SIZE] FILE1 FILE2 ...
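Projection searches for a latent whose generated image matches a target. The toy sketch below shows the core optimization loop under loud simplifications: a fixed linear map A stands in for the generator, and squared error stands in for the perceptual loss (such as LPIPS, whose code this repository includes) that a real projector would more likely use.

```python
from typing import List

# Toy projection by gradient descent. Assumptions: a fixed linear map A
# stands in for the generator, and 0.5 * squared error stands in for the
# perceptual loss a real projector would more likely use.

A = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]  # hypothetical 3-pixel "generator"


def generate(w: List[float]) -> List[float]:
    """Toy generator: image = A @ w."""
    return [sum(a * x for a, x in zip(row, w)) for row in A]


def loss(w: List[float], target: List[float]) -> float:
    return 0.5 * sum((g - t) ** 2 for g, t in zip(generate(w), target))


def project(target: List[float], steps: int = 200, lr: float = 0.1) -> List[float]:
    """Optimize a latent, starting from zero, so generate(w) matches target."""
    w = [0.0] * len(A[0])
    for _ in range(steps):
        residual = [g - t for g, t in zip(generate(w), target)]
        # Gradient of 0.5 * ||A w - target||^2 is A^T (A w - target).
        grad = [sum(row[k] * r for row, r in zip(A, residual)) for k in range(len(w))]
        w = [x - lr * g for x, g in zip(w, grad)]
    return w


if __name__ == "__main__":
    target = generate([0.5, -1.5])
    w = project(target)
    print(w, loss(w, target))
```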

Closed-Form Factorization (https://arxiv.org/abs/2007.06600)

You can use closed_form_factorization.py and apply_factor.py to discover meaningful latent semantic factors, or directions, in an unsupervised manner.
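In closed-form factorization, candidate directions are eigenvectors of W^T W, where W stacks the weights of the layers that map latents to styles. The sketch below finds only the top eigenvector, using power iteration in pure Python; the row-wise stacking of per-layer weights and the function names are illustrative assumptions, and a script would more typically take all eigenvectors at once from an SVD.

```python
import math
from typing import List

Matrix = List[List[float]]

# Sketch: top eigenvector of W^T W by power iteration. Assumption: W is
# built by stacking (row-wise) the weight matrices of several hypothetical
# latent-to-style layers, each with the same number of input columns.


def stack(layers: List[Matrix]) -> Matrix:
    """Concatenate per-layer weight matrices along the output (row) axis."""
    return [row for layer in layers for row in layer]


def top_direction(W: Matrix, iters: int = 100) -> List[float]:
    """Unit-norm top eigenvector of W^T W (the top right singular vector of W).

    Starts from the all-ones vector; assumes W is nonzero and that this start
    is not orthogonal to the top eigenvector.
    """
    n = len(W[0])
    v = [1.0] * n
    for _ in range(iters):
        wv = [sum(r[k] * v[k] for k in range(n)) for r in W]  # W v
        u = [sum(r[k] * s for r, s in zip(W, wv)) for k in range(n)]  # W^T W v
        norm = math.sqrt(sum(x * x for x in u))
        v = [x / norm for x in u]
    return v
```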

First, extract the eigenvectors of the weight matrices using closed_form_factorization.py:

python closed_form_factorization.py [CHECKPOINT]

This will create a factor file that contains the eigenvectors (default: factor.pt). You can then use apply_factor.py to test the meaning of the extracted directions:

python apply_factor.py -i [INDEX_OF_EIGENVECTOR] -d [DEGREE_OF_MOVE] -n [NUMBER_OF_SAMPLES] --ckpt [CHECKPOINT] [FACTOR_FILE]

For example,

python apply_factor.py -i 19 -d 5 -n 10 --ckpt [CHECKPOINT] factor.pt

This will generate 10 random samples, along with samples generated from latents moved along eigenvector 19 by a degree of ±5.
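Moving a latent along a discovered direction is plain vector arithmetic. The sketch below assumes the factor file stores eigenvectors as columns and that the degree simply scales the unit direction; both are assumptions, and the helper names are hypothetical.

```python
from typing import List, Tuple

Vector = List[float]

# Sketch: pick eigenvector `index` from a matrix whose columns are
# eigenvectors (an assumption about the factor file layout) and shift a
# latent by +/- degree along it. Helper names are hypothetical.


def column(matrix: List[Vector], index: int) -> Vector:
    """Return column `index` of a row-major matrix."""
    return [row[index] for row in matrix]


def move_along(latent: Vector, direction: Vector, degree: float) -> Tuple[Vector, Vector, Vector]:
    """Return (latent - degree*d, latent, latent + degree*d)."""
    shift = [degree * d for d in direction]
    minus = [x - s for x, s in zip(latent, shift)]
    plus = [x + s for x, s in zip(latent, shift)]
    return minus, list(latent), plus
```

Rendering all three latents side by side makes it easy to see which attribute a given eigenvector controls.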

Sample of closed form factorization

Pretrained Checkpoints

I have trained the 256px model on FFHQ for 550k iterations and got an FID of about 4.5. Differences in data preprocessing, resolution, or the training loop might explain the gap from the official results, but I currently don't know the exact reason for the FID difference.
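For reference, FID fits a Gaussian to Inception features of real images (mean μ_r, covariance Σ_r) and of generated images (μ_g, Σ_g) and measures the Fréchet distance between the two; lower is better:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Because the features come from a specific Inception V3 network, different Inception implementations can shift the reported value.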

Samples

Sample with truncation

Sample from FFHQ at 110,000 iterations (trained on 3.52M images).

MetFaces sample with non-leaking augmentations

Sample from MetFaces with non-leaking augmentations at 150,000 iterations (trained on 4.8M images).

Samples from converted weights

Sample from FFHQ (1024px)

Sample from LSUN Church (256px)

License

Model details and custom CUDA kernel code are from the official repository: https://github.com/NVlabs/stylegan2

Code for Learned Perceptual Image Patch Similarity (LPIPS) comes from https://github.com/richzhang/PerceptualSimilarity

To match FID scores more closely to the official TensorFlow implementation, I have used the FID Inception V3 implementation from https://github.com/mseitzer/pytorch-fid

Owner
Kim Seonghyeon