Natural Posterior Network: Deep Bayesian Predictive Uncertainty for Exponential Family Distributions

Overview


This repository provides the official implementation of the Natural Posterior Network (NatPN) and the Natural Posterior Ensemble (NatPE) as presented in the following paper:

Natural Posterior Network: Deep Bayesian Predictive Uncertainty for Exponential Family Distributions
Bertrand Charpentier*, Oliver Borchert*, Daniel Zügner, Simon Geisler, Stephan Günnemann
International Conference on Learning Representations, 2022

Features

The implementation of NatPN in this repository provides the following features:

  • High-level estimator interface that makes NatPN as easy to use as Scikit-learn estimators
  • Simple bash script to train and evaluate NatPN
  • Ready-to-use PyTorch Lightning data modules with 8 of the 9 datasets used in the paper*

In addition, we provide a public Weights & Biases project. This project will be filled with training and evaluation runs that allow you (1) to inspect the performance of different NatPN models and (2) to download the model parameters. See the example notebook for instructions on how to use such a pretrained model.

*The Kin8nm dataset is not included as it has disappeared from the UCI Repository.

Installation

Prior to installation, you may want to install all dependencies (Python, CUDA, Poetry). If you are running on an AWS EC2 instance with Ubuntu 20.04, you can use the provided bash script:

sudo bash bin/setup-ec2.sh

In order to use the code in this repository, you should first clone the repository:

git clone git@github.com:borchero/natural-posterior-network.git natpn

Then, in the root of the repository, you can install all dependencies via Poetry:

poetry install

Quickstart

Shell Script

To train and evaluate NatPN on a particular dataset, you can use the train shell script. For example, to train and evaluate NatPN on the Sensorless Drive dataset, run the following command in the root of the repository:

poetry run train --dataset sensorless-drive

The dataset is downloaded automatically the first time this command is called. The performance metrics of the trained model are printed to the console, and the trained model is then discarded. To track both the metrics and the model parameters via Weights & Biases, use the following command:

poetry run train --dataset sensorless-drive --experiment first-steps

To list all options of the shell script, simply run:

poetry run train --help

This command will also provide explanations for all the parameters that can be passed.

Estimator

If you want to use NatPN from your code, the easiest way to get started is to use the Scikit-learn-like estimator:

from natpn import NaturalPosteriorNetwork

The documentation of the estimator's __init__ method provides a comprehensive overview of all the configuration options. For a simple example of using the estimator, refer to the example notebook.
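The snippet below is a minimal sketch of that workflow, not the definitive interface: the SensorlessDriveDataModule import and the exact fit/score calls are assumptions based on the scikit-learn-like design described above, so check the estimator's __init__ documentation and the example notebook for the real names and options.

from natpn import NaturalPosteriorNetwork
# Assumption: the ready-to-use data modules live in natpn.datasets under this name.
from natpn.datasets import SensorlessDriveDataModule

# Create the estimator with its default configuration; every option is
# documented in the estimator's __init__ method.
estimator = NaturalPosteriorNetwork()

# Assumption: the estimator consumes the provided PyTorch Lightning data
# modules directly, mirroring scikit-learn's fit/score pattern.
data = SensorlessDriveDataModule()
estimator.fit(data)
print(estimator.score(data))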

Module

If you need even more customization, you can use natpn.nn.NaturalPosteriorNetworkModel directly. The natpn.nn package is extensively documented and lets you configure every component of your NatPN model.

Further, the natpn.model package provides PyTorch Lightning modules which allow you to train, evaluate, and fine-tune models.
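As a rough illustration of how the low-level module slots into plain PyTorch code, the sketch below treats it as an ordinary torch.nn.Module; how the model is constructed and what its forward pass returns are assumptions here, so consult the natpn.nn documentation for the actual signatures.

import torch
from natpn.nn import NaturalPosteriorNetworkModel

def predict(model: NaturalPosteriorNetworkModel, inputs: torch.Tensor):
    # Assumption: calling the model on a batch yields its posterior prediction;
    # the precise return type is described in the natpn.nn package.
    model.eval()
    with torch.no_grad():
        return model(inputs)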

Running Hyperparameter Searches

If you want to run hyperparameter searches on a local Slurm cluster, you can use the files provided in the sweeps directory. To run a grid search, simply execute the corresponding file:

poetry run python sweeps/<file>

To make sure that your experiment is tracked correctly, you should also set the WANDB_PROJECT environment variable in a place that is read by the slurm script (found in sweeps/slurm).

Feel free to adapt the scripts to your liking to run your own hyperparameter searches.

Citation

If you are using the model or the code in this repository, please cite the following paper:

@inproceedings{natpn,
    title={{Natural} {Posterior} {Network}: {Deep} {Bayesian} {Predictive} {Uncertainty} for {Exponential} {Family} {Distributions}},
    author={Charpentier, Bertrand and Borchert, Oliver and Z\"{u}gner, Daniel and Geisler, Simon and G\"{u}nnemann, Stephan},
    booktitle={International Conference on Learning Representations},
    year={2022}
}

Contact Us

If you have any questions regarding the code, please contact us via email.

License

The code in this repository is licensed under the MIT License.
