Tensorflow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution

Overview

Travis CI

Single Image Super-Resolution with EDSR, WDSR and SRGAN

A Tensorflow 2.x based implementation of

This is a complete re-write of the old Keras/Tensorflow 1.x based implementation available here. Some parts are still work in progress but you can already train models as described in the papers via a high-level training API. Furthermore, you can also fine-tune EDSR and WDSR models in an SRGAN context. Training and usage examples are given in the notebooks

A DIV2K data provider automatically downloads DIV2K training and validation images of given scale (2, 3, 4 or 8) and downgrade operator ("bicubic", "unknown", "mild" or "difficult").

Important: if you want to evaluate the pre-trained models with a dataset other than DIV2K please read this comment (and replies) first.

Environment setup

Create a new conda environment with

conda env create -f environment.yml

and activate it with

conda activate sisr

Introduction

You can find an introduction to single-image super-resolution in this article. It also demonstrates how EDSR and WDSR models can be fine-tuned with SRGAN (see also this section).

Getting started

Examples in this section require following pre-trained weights for running (see also example notebooks):

Pre-trained weights

  • weights-edsr-16-x4.tar.gz
    • EDSR x4 baseline as described in the EDSR paper: 16 residual blocks, 64 filters, 1.52M parameters.
    • PSNR on DIV2K validation set = 28.89 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-wdsr-b-32-x4.tar.gz
    • WDSR B x4 custom model: 32 residual blocks, 32 filters, expansion factor 6, 0.62M parameters.
    • PSNR on DIV2K validation set = 28.91 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-srgan.tar.gz
    • SRGAN as described in the SRGAN paper: 1.55M parameters, trained with VGG54 content loss.

After download, extract them in the root folder of the project with

tar xvfz weights-<...>.tar.gz

EDSR

from model import resolve_single
from model.edsr import edsr

from utils import load_image, plot_sample

model = edsr(scale=4, num_res_blocks=16)
model.load_weights('weights/edsr-16-x4/weights.h5')

lr = load_image('demo/0851x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

result-edsr

WDSR

from model.wdsr import wdsr_b

model = wdsr_b(scale=4, num_res_blocks=32)
model.load_weights('weights/wdsr-b-32-x4/weights.h5')

lr = load_image('demo/0829x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

result-wdsr

Weight normalization in WDSR models is implemented with the new WeightNormalization layer wrapper of Tensorflow Addons. In its latest version, this wrapper seems to corrupt weights when running model.predict(...). A workaround is to set model.run_eagerly = True or compile the model with model.compile(loss='mae') in advance. This issue doesn't arise when calling the model directly with model(...) though. To be further investigated ...

SRGAN

from model.srgan import generator

model = generator()
model.load_weights('weights/srgan/gan_generator.h5')

lr = load_image('demo/0869x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

result-srgan

DIV2K dataset

For training and validation on DIV2K images, applications should use the provided DIV2K data loader. It automatically downloads DIV2K images to .div2k directory and converts them to a different format for faster loading.

Training dataset

from data import DIV2K

train_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult' 
                     subset='train')      # Training dataset are images 001 - 800
                     
# Create a tf.data.Dataset          
train_ds = train_loader.dataset(batch_size=16,         # batch size as described in the EDSR and WDSR papers
                                random_transform=True, # random crop, flip, rotate as described in the EDSR paper
                                repeat_count=None)     # repeat iterating over training images indefinitely

# Iterate over LR/HR image pairs                                
for lr, hr in train_ds:
    # .... 

Crop size in HR images is 96x96.

Validation dataset

from data import DIV2K

valid_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult' 
                     subset='valid')      # Validation dataset are images 801 - 900
                     
# Create a tf.data.Dataset          
valid_ds = valid_loader.dataset(batch_size=1,           # use batch size of 1 as DIV2K images have different size
                                random_transform=False, # use DIV2K images in original size 
                                repeat_count=1)         # 1 epoch
                                
# Iterate over LR/HR image pairs                                
for lr, hr in valid_ds:
    # ....                                 

Training

The following training examples use the training and validation datasets described earlier. The high-level training API is designed around steps (= minibatch updates) rather than epochs to better match the descriptions in the papers.

EDSR

from model.edsr import edsr
from train import EdsrTrainer

# Create a training context for an EDSR x4 model with 16 
# residual blocks.
trainer = EdsrTrainer(model=edsr(scale=4, num_res_blocks=16), 
                      checkpoint_dir=f'.ckpt/edsr-16-x4')
                      
# Train EDSR model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)
              
# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/edsr-16-x4/weights.h5')                                    

Interrupting training and restarting it again resumes from the latest saved checkpoint. The trained Keras model can be accessed with trainer.model.

WDSR

from model.wdsr import wdsr_b
from train import WdsrTrainer

# Create a training context for a WDSR B x4 model with 32 
# residual blocks.
trainer = WdsrTrainer(model=wdsr_b(scale=4, num_res_blocks=32), 
                      checkpoint_dir=f'.ckpt/wdsr-b-8-x4')

# Train WDSR B model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)

# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/wdsr-b-32-x4/weights.h5')

SRGAN

Generator pre-training

from model.srgan import generator
from train import SrganGeneratorTrainer

# Create a training context for the generator (SRResNet) alone.
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir=f'.ckpt/pre_generator')

# Pre-train the generator with 1,000,000 steps (100,000 works fine too). 
pre_trainer.train(train_ds, valid_ds.take(10), steps=1000000, evaluate_every=1000)

# Save weights of pre-trained generator (needed for fine-tuning with GAN).
pre_trainer.model.save_weights('weights/srgan/pre_generator.h5')

Generator fine-tuning (GAN)

from model.srgan import generator, discriminator
from train import SrganTrainer

# Create a new generator and init it with pre-trained weights.
gan_generator = generator()
gan_generator.load_weights('weights/srgan/pre_generator.h5')

# Create a training context for the GAN (generator + discriminator).
gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())

# Train the GAN with 200,000 steps.
gan_trainer.train(train_ds, steps=200000)

# Save weights of generator and discriminator.
gan_trainer.generator.save_weights('weights/srgan/gan_generator.h5')
gan_trainer.discriminator.save_weights('weights/srgan/gan_discriminator.h5')

SRGAN for fine-tuning EDSR and WDSR models

It is also possible to fine-tune EDSR and WDSR x4 models with SRGAN. They can be used as drop-in replacement for the original SRGAN generator. More details in this article.

# Create EDSR generator and init with pre-trained weights
generator = edsr(scale=4, num_res_blocks=16)
generator.load_weights('weights/edsr-16-x4/weights.h5')

# Fine-tune EDSR model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)
# Create WDSR B generator and init with pre-trained weights
generator = wdsr_b(scale=4, num_res_blocks=32)
generator.load_weights('weights/wdsr-b-16-32/weights.h5')

# Fine-tune WDSR B  model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)
Owner
Martin Krasser
Freelance machine learning engineer, software developer and consultant. Mountainbike freerider, bass guitar player.
Martin Krasser
UPSNet: A Unified Panoptic Segmentation Network

UPSNet: A Unified Panoptic Segmentation Network Introduction UPSNet is initially described in a CVPR 2019 oral paper. Disclaimer This repository is te

Uber Research 622 Dec 26, 2022
Extracting and filtering paraphrases by bridging natural language inference and paraphrasing

nli2paraphrases Source code repository accompanying the preprint Extracting and filtering paraphrases by bridging natural language inference and parap

Matej Klemen 1 Mar 09, 2022
A Keras implementation of CapsNet in the paper: Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules

NOTE This implementation is fork of https://github.com/XifengGuo/CapsNet-Keras , applied to IMDB texts reviews dataset. CapsNet-Keras A Keras implemen

Lauro Moraes 5 Oct 23, 2022
This project is the official implementation of our accepted ICLR 2021 paper BiPointNet: Binary Neural Network for Point Clouds.

BiPointNet: Binary Neural Network for Point Clouds Created by Haotong Qin, Zhongang Cai, Mingyuan Zhang, Yifu Ding, Haiyu Zhao, Shuai Yi, Xianglong Li

Haotong Qin 59 Dec 17, 2022
kapre: Keras Audio Preprocessors

Kapre Keras Audio Preprocessors - compute STFT, ISTFT, Melspectrogram, and others on GPU real-time. Tested on Python 3.6 and 3.7 Why Kapre? vs. Pre-co

Keunwoo Choi 867 Dec 29, 2022
Alignment Attention Fusion framework for Few-Shot Object Detection

AAF framework Framework generalities This repository contains the code of the AAF framework proposed in this paper. The main idea behind this work is

Pierre Le Jeune 20 Dec 16, 2022
Classifying audio using Wavelet transform and deep learning

Audio Classification using Wavelet Transform and Deep Learning A step-by-step tutorial to classify audio signals using continuous wavelet transform (C

Aditya Dutt 17 Nov 29, 2022
Physics-Aware Training (PAT) is a method to train real physical systems with backpropagation.

Physics-Aware Training (PAT) is a method to train real physical systems with backpropagation. It was introduced in Wright, Logan G. & Onodera, Tatsuhiro et al. (2021)1 to train Physical Neural Networ

McMahon Lab 230 Jan 05, 2023
App customer segmentation cohort rfm clustering

CUSTOMER SEGMENTATION COHORT RFM CLUSTERING TỔNG QUAN VỀ HỆ THỐNG DỮ LIỆU Nên chuyển qua theme màu dark thì sẽ nhìn đẹp hơn https://customer-segmentat

hieulmsc 3 Dec 18, 2021
This package contains a PyTorch Implementation of IB-GAN of the submitted paper in AAAI 2021

The PyTorch implementation of IB-GAN model of AAAI 2021 This package contains a PyTorch implementation of IB-GAN presented in the submitted paper (IB-

Insu Jeon 9 Mar 30, 2022
A Deep Learning Framework for Neural Derivative Hedging

NNHedge NNHedge is a PyTorch based framework for Neural Derivative Hedging. The following repository was implemented to ease the experiments of our pa

GUIJIN SON 17 Nov 14, 2022
Official Tensorflow implementation of "M-LSD: Towards Light-weight and Real-time Line Segment Detection"

M-LSD: Towards Light-weight and Real-time Line Segment Detection Official Tensorflow implementation of "M-LSD: Towards Light-weight and Real-time Line

NAVER/LINE Vision 357 Jan 04, 2023
Pytoydl: A toy deep learning framework built upon numpy.

Documents: https://pytoydl.readthedocs.io/zh/latest/ Pytoydl A toy deep learning framework built upon numpy. You can star this repository to keep trac

28 Dec 10, 2022
Dense Unsupervised Learning for Video Segmentation (NeurIPS*2021)

Dense Unsupervised Learning for Video Segmentation This repository contains the official implementation of our paper: Dense Unsupervised Learning for

Visual Inference Lab @TU Darmstadt 173 Dec 26, 2022
《K-Adapter: Infusing Knowledge into Pre-Trained Models with Adapters》(2020)

K-Adapter: Infusing Knowledge into Pre-Trained Models with Adapters This repository is the implementation of the paper "K-Adapter: Infusing Knowledge

Microsoft 118 Dec 13, 2022
Implementation of Uformer, Attention-based Unet, in Pytorch

Uformer - Pytorch Implementation of Uformer, Attention-based Unet, in Pytorch. It will only offer the concat-cross-skip connection. This repository wi

Phil Wang 72 Dec 19, 2022
Fastshap: A fast, approximate shap kernel

fastshap: A fast, approximate shap kernel fastshap was designed to be: Fast Calculating shap values can take an extremely long time. fastshap utilizes

Samuel Wilson 22 Sep 24, 2022
Super Pix Adv - Offical implemention of Robust Superpixel-Guided Attentional Adversarial Attack (CVPR2020)

Super_Pix_Adv Offical implemention of Robust Superpixel-Guided Attentional Adver

DLight 8 Oct 26, 2022
[NeurIPS '21] Adversarial Attacks on Graph Classification via Bayesian Optimisation (GRABNEL)

Adversarial Attacks on Graph Classification via Bayesian Optimisation @ NeurIPS 2021 This repository contains the official implementation of GRABNEL,

Xingchen Wan 12 Dec 23, 2022
Trajectory Variational Autoencder baseline for Multi-Agent Behavior challenge 2022

MABe_2022_TVAE: a Trajectory Variational Autoencoder baseline for the 2022 Multi-Agent Behavior challenge This repository contains jupyter notebooks t

Andrew Ulmer 15 Nov 08, 2022