Code for "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations"

Overview

Infinitely Deep Bayesian Neural Networks with SDEs

This library contains JAX and PyTorch implementations of neural ODEs and Bayesian layers for stochastic variational inference. A rudimentary JAX implementation of differentiable SDE solvers is also provided; for a full set of differentiable SDE solvers in PyTorch, refer to torchsde [2], and for differentiable ODE solvers, refer to torchdiffeq [3].

Continuous-depth hidden unit trajectories in Neural ODE vs uncertain posterior dynamics SDE-BNN.

Installation

This library runs on jax==0.1.77 and torch==1.6.0. To install all other requirements:

pip install -r requirements.txt

Note: Package versions may change; refer to the official JAX installation instructions.

JaxSDE: Differentiable SDE Solvers in JAX

The jaxsde library contains SDE solvers for both the Ito and Stratonovich forms. Select a solver with method={euler_maruyama|milstein|euler_heun}; these have strong orders 0.5|1|0.5 in general, and 1|1|1 for an SDE with additive noise. The stochastic adjoint training mode (sdeint_ito) does not work efficiently yet; use sdeint_ito_fixed_grid for now. Trade off solver speed against precision during training or inference by adjusting --nsteps <# steps>.
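To make the difference in strong order concrete, the sketch below integrates geometric Brownian motion, which has a closed-form solution, with Euler-Maruyama and Milstein on the same Brownian paths and compares the mean absolute error at the final time. It is plain Python and does not use jaxsde: the helper sdeint_fixed_grid, its dg argument (the derivative of g with respect to y), and the constants MU and SIGMA are assumptions made for illustration, not part of the library's API.

```python
import math
import random

MU, SIGMA = 0.5, 1.0  # illustrative drift and diffusion constants (assumed)


def sdeint_fixed_grid(f, g, dg, y0, ts, rng, method="euler_maruyama"):
    """Integrate the scalar Ito SDE dy = f(y, t) dt + g(y, t) dW on the grid ts.

    Hypothetical helper for illustration only; not the jaxsde function.
    Returns the final state and the Brownian motion value at ts[-1].
    """
    if method not in ("euler_maruyama", "milstein"):
        raise ValueError(f"unknown method: {method}")
    y, w = y0, 0.0
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dt = t1 - t0
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment ~ N(0, dt)
        step = f(y, t0) * dt + g(y, t0) * dw
        if method == "milstein":
            # Ito-Milstein correction: 0.5 * g * dg/dy * (dW^2 - dt)
            step += 0.5 * g(y, t0) * dg(y, t0) * (dw * dw - dt)
        y += step
        w += dw
    return y, w


def drift(y, t):
    return MU * y


def diffusion(y, t):
    return SIGMA * y


def diffusion_dy(y, t):
    return SIGMA


def strong_error(method, nsteps, npaths=500, seed=0):
    """Mean absolute error at t = 1 against the exact GBM solution."""
    rng = random.Random(seed)  # same seed gives the same paths for each method
    ts = [i / nsteps for i in range(nsteps + 1)]
    total = 0.0
    for _ in range(npaths):
        y1, w1 = sdeint_fixed_grid(drift, diffusion, diffusion_dy, 1.0, ts, rng, method)
        exact = math.exp((MU - 0.5 * SIGMA ** 2) * ts[-1] + SIGMA * w1)
        total += abs(y1 - exact)
    return total / npaths


em_err = strong_error("euler_maruyama", nsteps=64)
milstein_err = strong_error("milstein", nsteps=64)
print(f"euler_maruyama: {em_err:.4f}  milstein: {milstein_err:.4f}")
```

Because the noise here is multiplicative, Milstein's extra term matters; with additive noise (dg = 0) the two schemes coincide. Increasing nsteps lowers both errors at the cost of more solver steps, which is the trade-off that --nsteps controls.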

Usage

Default solver: backpropagates directly through the solver's operations.

from jaxsde.jaxsde.sdeint import sdeint_ito_fixed_grid

y1 = sdeint_ito_fixed_grid(f, g, y0, ts, rng, fw_params, method="euler_maruyama")

Stochastic adjoint: uses O(1) memory by solving an adjoint SDE in the backward pass instead of storing the solver's intermediate states.

from jaxsde.jaxsde.sdeint import sdeint_ito

y1 = sdeint_ito(f, g, y0, ts, rng, fw_params, method="milstein")

Brax: Bayesian SDE Framework in JAX

Brax implements composable Bayesian layers in the stax API. The SDE Bayesian layers are used through the SDEBNN block, which can be composed with several parameterizations of time-dependent layers from diffeq_layers. Enable the sticking-the-landing (STL) trick during training with --stl to improve the convergence rate. Augment the inputs by a custom number of dimensions with --aug <integer>, and set the number of samples averaged over with --nsamples <integer>. If memory is a constraint, train with gradient accumulation (--acc_grad) and gradient checkpointing (--remat).
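As a rough illustration of why sticking-the-landing can help convergence, the sketch below compares the standard reparameterized gradient with the STL gradient for a one-dimensional Gaussian approximate posterior. This is a simplified stand-in, not the SDE posterior used by SDEBNN and not the code behind --stl; the function elbo_grad_mu and all variable names are hypothetical. STL drops the score term of the gradient, so when the approximate posterior equals the target every single-sample gradient is exactly zero, while the standard estimator still fluctuates.

```python
import random
import statistics


def elbo_grad_mu(eps, mu, sigma, m, s, stl):
    """Single-sample gradient of log p(z) - log q(z) with respect to mu.

    p = N(m, s^2) is the target, q = N(mu, sigma^2) the approximate
    posterior, and z = mu + sigma * eps is a reparameterized sample.
    Hypothetical helper for illustration only.
    """
    z = mu + sigma * eps
    dlogp_dz = -(z - m) / s ** 2
    dlogq_dz = -(z - mu) / sigma ** 2
    path_term = dlogp_dz - dlogq_dz      # gradient flowing through the sample z
    score_term = (z - mu) / sigma ** 2   # d log q / d mu with z held fixed
    # STL keeps only the path term; the standard estimator subtracts the score term.
    return path_term if stl else path_term - score_term


rng = random.Random(0)
noise = [rng.gauss(0.0, 1.0) for _ in range(1000)]

# Evaluate both estimators at the optimum, where q equals p.
mu, sigma, m, s = 1.0, 0.5, 1.0, 0.5
full_grads = [elbo_grad_mu(e, mu, sigma, m, s, stl=False) for e in noise]
stl_grads = [elbo_grad_mu(e, mu, sigma, m, s, stl=True) for e in noise]

print("variance of standard estimator:", statistics.pvariance(full_grads))
print("variance of STL estimator:", statistics.pvariance(stl_grads))
```

Both estimators have the same expectation, because the score term has zero mean; the difference is only in variance near the optimum.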

Samples from SDEBNN-learned predictive prior and posterior density distributions.

Usage

Any example can be run with a different vision dataset. For readability, TensorBoard logging has been left out of these scripts (see torchbnn for a version with logging).

Toy 1D regression to learn complex posteriors:

python examples/jax/sdebnn_toy1d.py --ds cos --activn swish --loss laplace --kl_scale 1. --diff_const 0.2 --driftw_scale 0.1 --aug_dim 2 --stl --prior_dw ou

Image Classification:

To train an SDEBNN model:

python examples/jax/sdebnn_classification.py --output <output directory> --model sdenet --aug 2 --nblocks 2-2-2 --diff_coef 0.2 --fx_dim 64 --fw_dims 2-64-2 --nsteps 20 --nsamples 1

To train a ResNet baseline, specify --model resnet; for a Bayesian ResNet baseline, specify --meanfield_sdebnn.

TorchBNN: SDE-BNN in PyTorch

A PyTorch implementation of the Brax framework powered by the torchsde backend.

Usage

Any example can be run with a different vision dataset, and each includes TensorBoard logging for critical metrics.

Toy 1D regression to learn a multi-modal posterior:

python examples/torch/sdebnn_toy1d.py --output_dir <dst_path>

Arbitrarily expressive approximate posteriors from learning non-Gaussian marginals.

Image Classification:

All hyperparameters can be found in the training script. Train with the adjoint for memory-efficient backpropagation and with adaptive mode for adaptive computation; set --adjoint_adaptive True when training with both adjoint and adaptive modes.

python examples/torch/sdebnn_classification.py --train-dir <output directory> --data cifar10 --dt 0.05 --method midpoint --adjoint True --adaptive True --adjoint_adaptive True --inhomogeneous True

References

[1] Winnie Xu, Ricky T. Q. Chen, Xuechen Li, David Duvenaud. "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations." Preprint 2021. [arxiv]

[2] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations." AISTATS 2020. [arxiv]

[3] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." NeurIPS 2018. [arxiv]


If you found this library useful in your research, please consider citing

@article{xu2021sdebnn,
  title={Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations},
  author={Xu, Winnie and Chen, Ricky T. Q. and Li, Xuechen and Duvenaud, David},
  archivePrefix = {arXiv},
  year={2021}
}