Autoregressive Models in PyTorch.

Overview

Autoregressive

This repository contains all the necessary PyTorch code, tailored to my presentation, to train and generate data from WaveNet-like autoregressive models.

For presentation purposes, the WaveNet-like models are applied to randomized Fourier series (1D) and MNIST (2D). In the figure below, two WaveNet-like models with different training settings make an n-step prediction on a periodic time-series from the validation dataset.

Advanced functions show how to generate MNIST images and how to estimate the MNIST digit class (progressively) p(y=class|x) from observed pixels using a conditional WaveNet p(x|y=class) and Bayes rule. Left: sampled MNIST digits, right: progressive class estimates as more pixels are observed.

Note, this library does not implement (Gated) PixelCNNs, but unrolls images for the purpose of processing in WaveNet architectures. This works surprisingly well.

Features

Currently the following features are implemented

  • WaveNet architecture and training as proposed in (oord2016wavenet)
  • Conditioning support (oord2016wavenet)
  • Fast generation based on (paine2016fast)
  • Fully differentiable n-step unrolling in training (heindl2021autoreg)
  • 2D image generation, completion, classification, and progressive classification support based on MNIST dataset
  • A randomized Fourier dataset

Presentation

A detailed presentation with theoretical background, architectural considerations and experiments can be found below.

The presentation source as well as all generated images are public domain. In case you find them useful, please leave a citation (see References below). All presentation sources can be found in etc/presentation. The presentation is written in markdown using Marp, graph diagrams are created using yEd.

If you spot errors or if case you have suggestions for improvements, please let me know by opening an issue.

Installation

To install run,

pip install https://github.com/cheind/autoregressive.git#egg=autoregressive[dev]

which requires Python 3.9 and a recent PyTorch > 1.9

Usage

The library comes with a set of pre-trained models in models/. The following commands use those models to make various predictions. Many listed commands come with additional parameters; use --help to get additional information.

1D Fourier series

Sample new signals from scratch

python -m autoregressive.scripts.wavenet_signals sample --config "models/fseries_q127/config.yaml" --ckpt "models/fseries_q127/xxxxxx.ckpt" --condition 4 --horizon 1000

The default models conditions on the periodicity of the signal. For the pre-trained model the value range is int: [0..4], corresponding to periods of 5-10secs.


Predict the shape of partially observable curves.

python -m autoregressive.scripts.wavenet_signals predict --config "models/fseries_q127/config.yaml" --ckpt "models/fseries_q127/xxxxxx.ckpt" --horizon 1500 --num_observed 50 --num_trajectories 20 --num_curves 1 --show_confidence true

2D MNIST

To sample from the class-conditional model

python -m autoregressive.scripts.wavenet_mnist sample --config "models/mnist_q2/config.yaml" --ckpt "models/mnist_q2/xxxxxx.ckpt"

Generate images conditioned on the digit class and observed pixels.

python -m autoregressive.scripts.wavenet_mnist predict --config "models/mnist_q2/config.yaml" --ckpt "models/mnist_q2/xxxxxx.ckpt" 

To perform classification

python -m autoregressive.scripts.wavenet_mnist classify --config "models/mnist_q2/config.yaml" --ckpt "models/mnist_q2/xxxxxx.ckpt"

Train

To train / reproduce a model

python -m autoregressive.scripts.train fit --config "models/mnist_q2/config.yaml"

Progress is logged to Tensorboard

tensorboard --logdir lightning_logs

To generate a training configuration file for a specific dataset use

python -m autoregressive.scripts.train fit --data autoregressive.datasets.FSeriesDataModule --print_config > fseries_config.yaml

Test

To run the tests

pytest

References

@misc{heindl2021autoreg, 
  title={Autoregressive Models}, 
  journal={PROFACTOR Journal Club}, 
  author={Heindl, Christoph},
  year={2021},
  howpublished={\url{https://github.com/cheind/autoregressive}}
}

@article{oord2016wavenet,
  title={Wavenet: A generative model for raw audio},
  author={Oord, Aaron van den and Dieleman, Sander and Zen, Heiga and Simonyan, Karen and Vinyals, Oriol and Graves, Alex and Kalchbrenner, Nal and Senior, Andrew and Kavukcuoglu, Koray},
  journal={arXiv preprint arXiv:1609.03499},
  year={2016}
}

@article{paine2016fast,
  title={Fast wavenet generation algorithm},
  author={Paine, Tom Le and Khorrami, Pooya and Chang, Shiyu and Zhang, Yang and Ramachandran, Prajit and Hasegawa-Johnson, Mark A and Huang, Thomas S},
  journal={arXiv preprint arXiv:1611.09482},
  year={2016}
}

@article{oord2016conditional,
  title={Conditional image generation with pixelcnn decoders},
  author={Oord, Aaron van den and Kalchbrenner, Nal and Vinyals, Oriol and Espeholt, Lasse and Graves, Alex and Kavukcuoglu, Koray},
  journal={arXiv preprint arXiv:1606.05328},
  year={2016}
}
Owner
Christoph Heindl
I am a scientist at PROFACTOR/JKU working at the interface between computer vision, robotics and deep learning.
Christoph Heindl
[AAAI 2022] Negative Sample Matters: A Renaissance of Metric Learning for Temporal Grounding

[AAAI 2022] Negative Sample Matters: A Renaissance of Metric Learning for Temporal Grounding Official Pytorch implementation of Negative Sample Matter

Multimedia Computing Group, Nanjing University 69 Dec 26, 2022
A PyTorch Implementation of Gated Graph Sequence Neural Networks (GGNN)

A PyTorch Implementation of GGNN This is a PyTorch implementation of the Gated Graph Sequence Neural Networks (GGNN) as described in the paper Gated G

Ching-Yao Chuang 427 Dec 13, 2022
Unofficial implementation of HiFi-GAN+ from the paper "Bandwidth Extension is All You Need" by Su, et al.

HiFi-GAN+ This project is an unoffical implementation of the HiFi-GAN+ model for audio bandwidth extension, from the paper Bandwidth Extension is All

Brent M. Spell 134 Dec 30, 2022
Deep Q-Learning Network in pytorch (not actively maintained)

pytoch-dqn This project is pytorch implementation of Human-level control through deep reinforcement learning and I also plan to implement the followin

Hung-Tu Chen 342 Jan 01, 2023
BboxToolkit is a tiny library of special bounding boxes.

BboxToolkit is a light codebase collecting some practical functions for the special-shape detection, such as oriented detection

jbwang1997 73 Jan 01, 2023
Pretraining Representations For Data-Efficient Reinforcement Learning

Pretraining Representations For Data-Efficient Reinforcement Learning Max Schwarzer, Nitarshan Rajkumar, Michael Noukhovitch, Ankesh Anand, Laurent Ch

Mila 40 Dec 11, 2022
A Repository of Community-Driven Natural Instructions

A Repository of Community-Driven Natural Instructions TLDR; this repository maintains a community effort to create a large collection of tasks and the

AI2 244 Jan 04, 2023
PyTorch Implementation for Deep Metric Learning Pipelines

Easily Extendable Basic Deep Metric Learning Pipeline Karsten Roth ([email 

Karsten Roth 543 Jan 04, 2023
Stereo Radiance Fields (SRF): Learning View Synthesis for Sparse Views of Novel Scenes

Stereo Radiance Fields (SRF): Learning View Synthesis for Sparse Views of Novel Scenes

111 Dec 29, 2022
DeepHyper: Scalable Asynchronous Neural Architecture and Hyperparameter Search for Deep Neural Networks

What is DeepHyper? DeepHyper is a software package that uses learning, optimization, and parallel computing to automate the design and development of

DeepHyper Team 214 Jan 08, 2023
An alarm clock coded in Python 3 with Tkinter

Tkinter-Alarm-Clock An alarm clock coded in Python 3 with Tkinter. Run python3 Tkinter Alarm Clock.py in a terminal if you have Python 3. NOTE: This p

CodeMaster7000 1 Dec 25, 2021
This repository contains code used to audit the stability of personality predictions made by two algorithmic hiring systems

Stability Audit This repository contains code used to audit the stability of personality predictions made by two algorithmic hiring systems, Humantic

Data, Responsibly 4 Oct 27, 2022
QHack—the quantum machine learning hackathon

Official repo for QHack—the quantum machine learning hackathon

Xanadu 72 Dec 21, 2022
This is a collection of simple PyTorch implementations of neural networks and related algorithms. These implementations are documented with explanations,

labml.ai Deep Learning Paper Implementations This is a collection of simple PyTorch implementations of neural networks and related algorithms. These i

labml.ai 16.4k Jan 09, 2023
This code is a near-infrared spectrum modeling method based on PCA and pls

Nirs-Pls-Corn This code is a near-infrared spectrum modeling method based on PCA and pls 近红外光谱分析技术属于交叉领域,需要化学、计算机科学、生物科学等多领域的合作。为此,在(北邮邮电大学杨辉华老师团队)指导下

Fu Pengyou 6 Dec 17, 2022
PyTorch implementation of Pointnet2/Pointnet++

Pointnet2/Pointnet++ PyTorch Project Status: Unmaintained. Due to finite time, I have no plans to update this code and I will not be responding to iss

Erik Wijmans 1.2k Dec 29, 2022
This is an official implementation of our CVPR 2021 paper "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression" (https://arxiv.org/abs/2104.02300)

Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression Introduction In this paper, we are interested in the bottom-up paradigm of estima

HRNet 367 Dec 27, 2022
A MatConvNet-based implementation of the Fully-Convolutional Networks for image segmentation

MatConvNet implementation of the FCN models for semantic segmentation This package contains an implementation of the FCN models (training and evaluati

VLFeat.org 175 Feb 18, 2022
Unofficial implementation of the Involution operation from CVPR 2021

involution_pytorch Unofficial PyTorch implementation of "Involution: Inverting the Inherence of Convolution for Visual Recognition" by Li et al. prese

Rishabh Anand 46 Dec 07, 2022
The official implementation of NeMo: Neural Mesh Models of Contrastive Features for Robust 3D Pose Estimation [ICLR-2021]. https://arxiv.org/pdf/2101.12378.pdf

NeMo: Neural Mesh Models of Contrastive Features for Robust 3D Pose Estimation [ICLR-2021] Release Notes The offical PyTorch implementation of NeMo, p

Angtian Wang 76 Nov 23, 2022