Official PyTorch implementation of Spatial Dependency Networks.

Overview

Spatial Dependency Networks: Neural Layers for Improved Generative Image Modeling



Example of SDN-VAE generated images.

Method Description

The spatial dependency network (SDN) is a novel neural architecture. It is based on spatial dependency layers, which are designed for stacking into deep neural networks that produce images, e.g. generative models such as VAEs or GANs, or segmentation, super-resolution, and image-to-image translation networks. SDNs improve upon CNNs by explicitly modeling spatial dependencies between feature vectors at each level of a deep neural network pipeline. Spatial dependency layers (i) explicitly introduce the inductive bias of spatial coherence; and (ii) offer improved modeling of long-range dependencies due to their unbounded receptive field. We applied SDN to two variants of VAE: one that we used to model image density (SDN-VAE) and one that we used to learn better disentangled representations. More generally, spatial dependency layers can be used as a drop-in replacement for convolutional layers in any image-generation-related task.

Graphical model of SDN layer.
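
To make the "unbounded receptive field" claim above concrete, here is a minimal pure-Python toy. It is NOT the SDN layer from 'lib/nn.py': the functions `conv3x3_mean`, `sweep_right_down`, and `influence`, as well as the fixed `gate` value, are hypothetical names and simplifications chosen for illustration. The sketch only shows that a single directional sweep, in which every cell is updated from its already-updated neighbours, lets one corner of the grid influence the opposite corner, whereas a single 3x3 filter does not.

```python
def conv3x3_mean(grid):
    """One 3x3 mean filter with zero padding: each output sees only adjacent cells."""
    h, w = len(grid), len(grid[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            total = 0.0
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    ii, jj = i + di, j + dj
                    if 0 <= ii < h and 0 <= jj < w:
                        total += grid[ii][jj]
            out[i][j] = total / 9.0
    return out


def sweep_right_down(grid, gate=0.5):
    """Toy directional sweep (assumption: a fixed scalar gate instead of learned gating).

    Cells are visited top-left to bottom-right, and each cell is blended with
    the already-updated left and upper neighbours, so information propagates
    across the whole grid in a single pass.
    """
    h, w = len(grid), len(grid[0])
    out = [row[:] for row in grid]
    for i in range(h):
        for j in range(w):
            neighbours = []
            if j > 0:
                neighbours.append(out[i][j - 1])
            if i > 0:
                neighbours.append(out[i - 1][j])
            if neighbours:
                context = sum(neighbours) / len(neighbours)
                out[i][j] = gate * grid[i][j] + (1.0 - gate) * context
    return out


def influence(layer, size=8):
    """How much does changing the top-left input change the bottom-right output?"""
    base = [[0.0] * size for _ in range(size)]
    bumped = [row[:] for row in base]
    bumped[0][0] = 1.0
    before = layer(base)[size - 1][size - 1]
    after = layer(bumped)[size - 1][size - 1]
    return abs(after - before)


print("3x3 filter:", influence(conv3x3_mean))
print("directional sweep:", influence(sweep_right_down))
```

In this toy the 3x3 filter reports zero influence between opposite corners, while the sweep reports a small positive value. Sweeping in several directions would cover dependencies in the other directions as well.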

Code Structure

.
├── checkpoints/               # where the model checkpoints will be stored
├── data/
│    ├── ImageNet32/           # where ImageNet32 data is stored
│    ├── CelebAHQ256/          # where Celeb data is stored
│    ├── 3DShapes/             # where 3DShapes data is stored
│    ├── lmdb_datasets.py      # LMDB data loading borrowed from https://github.com/NVlabs/NVAE
│    └── get_dataset.py        # auxiliary script for fetching data sets
├── figs/                      # figures from the paper
├── lib/
│    ├── DensityVAE            # SDN-VAE which we used for density estimation
│    ├── DisentanglementVAE    # VAE which we used for disentanglement
│    ├── nn.py                 # the script which contains SDN and other neural net modules
│    ├── probability.py        # probability models
│    └── utils.py              # utility functions
├── train.py                   # generic training script
├── evaluate.py                # the script for evaluation of trained models
├── train_cifar.sh             # for reproducing CIFAR10 experiments
├── train_celeb.sh             # for reproducing CelebAHQ256 experiments
├── train_imagenet.sh          # for reproducing ImageNet32 experiments
├── train_3dshapes.sh          # for reproducing 3DShapes experiments
├── requirements.txt
├── LICENSE
└── README.md

Applying SDN layers to your neural network

To apply SDN layers in your framework, it is sufficient to integrate the 'lib/nn.py' file into your code. You can then import and use SDNLayer or ResSDNLayer (the residual variant) in the same way a convolutional layer is used. Apart from PyTorch, no additional packages are required.

Tips & Tricks

If you would like to integrate SDN into your neural network, we recommend the following:

  • first design and debug your framework using vanilla CNN layers.
  • replace CNN layers one by one. Start with the lowest scale, e.g. 4x4 or 8x8, to speed up debugging.
  • start with 1 or 2 directions, and later try using 4 directions.
  • a larger number of features per SDN layer gives a more expressive model, which is more powerful but also more prone to overfitting.
  • a good idea is to use a smaller number of SDN features on smaller scales and a larger number on larger scales.
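
The tips above can be turned into a simple replacement plan. The sketch below is only an illustration: the helper `sdn_rollout_plan`, its parameters, and the rule that the feature count grows linearly with the scale are all hypothetical choices, not values taken from the paper or from 'lib/nn.py'.

```python
def sdn_rollout_plan(scales, base_features=64, start_directions=1):
    """Order the scales from lowest to highest and assign a feature count to each.

    Assumption (hypothetical rule): the smallest scale gets `base_features`
    and the count grows proportionally to the spatial resolution.
    """
    if not scales:
        raise ValueError("scales must not be empty")
    smallest = min(scales)
    plan = []
    for scale in sorted(set(scales)):  # lowest scale first, as recommended
        plan.append({
            "scale": scale,
            "features": base_features * scale // smallest,
            "directions": start_directions,  # raise to 4 once training is stable
        })
    return plan


# Example: a decoder with feature maps at 4x4, 8x8, 16x16 and 32x32.
for step in sdn_rollout_plan([32, 4, 16, 8], base_features=32):
    print(step)
```

Each entry describes one replacement step: swap the CNN layer at that scale, verify that training still behaves, and only then move on to the next entry.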

Reproducing the experiments from the paper

Common to all experiments, you will need to install PyTorch and PyTorchLightning. The default logging system is based on Wandb, but this can be changed in 'train.py'. In case you decide to use Wandb, you will need to install it and then log in to your account, following the very simple procedure described here.

To reproduce the density estimation experiments you will need 8 Tesla V100 GPUs with 32 GB of memory. One way to alleviate the high memory requirements is to accumulate gradients over several batches; however, training will take much longer in that case. By default, you will need hardware that supports automatic mixed precision. In case your hardware does not support it, you will need to reduce the batch size; note, however, that the results will slightly deteriorate and that you will possibly need to reduce the learning rate too, to avoid NaN values. For the disentanglement experiments, you will need a single GPU with more than 10 GB of memory.

To install all the requirements use:

pip install -r requirements.txt

Note of caution: Ensure the right version of PyTorchLightning is used. We found multiple issues in the newer versions.
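
For the gradient accumulation option mentioned above, the number of batches to accumulate follows from simple arithmetic. The helper below is a minimal sketch: `accumulation_steps` is a hypothetical function, and the batch sizes in the example are made-up numbers, not the settings used in the training scripts. PyTorch Lightning's Trainer accepts an `accumulate_grad_batches` argument for this purpose; whether and how 'train.py' exposes it should be checked in the script itself.

```python
import math


def accumulation_steps(target_batch, available_gpus, per_gpu_batch):
    """Number of batches to accumulate so that the effective batch size
    (GPUs x per-GPU batch x accumulation steps) reaches `target_batch`."""
    if target_batch < 1 or available_gpus < 1 or per_gpu_batch < 1:
        raise ValueError("all arguments must be positive integers")
    per_step = available_gpus * per_gpu_batch  # samples processed per forward/backward pass
    return math.ceil(target_batch / per_step)


# Hypothetical numbers: an 8-GPU run with 16 samples per GPU has an effective
# batch size of 128. Matching it on 2 GPUs requires accumulating 4 batches.
print(accumulation_steps(target_batch=128, available_gpus=2, per_gpu_batch=16))
```

Since every optimizer step then needs several forward and backward passes, wall-clock training time grows roughly in proportion to the number of accumulated batches.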

CIFAR10

The data will be automatically downloaded through PyTorch. To run the baselines that reproduce the results from the paper use:

bash train_cifar.sh

ImageNet32

To obtain the dataset, go into the folder 'data/ImageNet32' and then run:

bash get_imagenet_data.sh

To reproduce the experiments run:

bash train_imagenet.sh

CelebAHQ256

To obtain the dataset, go into the folder 'data/CelebAHQ256' and then run:

bash get_celeb_data.sh

The script is adapted from the NVAE repo and is based on the GLOW dataset. To reproduce the experiments run:

bash train_celeb.sh

3DShapes

To obtain the dataset, follow the instructions on this GitHub repo and place it into the 'data/3DShapes' directory. To reproduce the experiments run:

bash train_3dshapes.sh

Evaluation of trained models

To perform post hoc evaluation of your trained models, use the 'evaluate.py' script and select the flags corresponding to the evaluation task and the model you want to use. The evaluation can be performed on a single GPU of any type, though note that the batch size needs to be adjusted depending on the available GPU memory. For the CelebAHQ256 dataset, you can download a checkpoint containing one of the pre-trained models that we used in the paper from this link. For example, you can evaluate the ELBO and generate random samples by running:

python3 evaluate.py --model CelebAHQ256 --elbo --sampling

Citation

Please cite our paper if you use our code or if you re-implement our method:

@conference{miladinovic21sdn,
  title = {Spatial Dependency Networks: Neural Layers for Improved Generative Image Modeling},
  author = {Miladinović, {\DJ}or{\dj}e and Stanić, Aleksandar and Bauer, Stefan and Schmidhuber, J{\"u}rgen and Buhmann, Joachim M.},
  booktitle = {9th International Conference on Learning Representations (ICLR 2021)},
  month = may,
  year = {2021},
  month_numeric = {5}
}

Note that you might need to include the following line in your LaTeX file:

\usepackage[T1]{fontenc}

Owner
Djordje Miladinovic
Machine learning R&D.