Bayes-Newton—A Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's algorithm.

Overview

Bayes-Newton

Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in JAX (with objax), built and actively maintained by Will Wilkinson.

Bayes-Newton provides a unifying view of approximate Bayesian inference, and allows for the combination of many models (e.g. GPs, sparse GPs, Markov GPs, sparse Markov GPs) with the inference method of your choice (VI, EP, Laplace, Linearisation). For a full list of the methods implemented scroll down to the bottom of this page.

Installation

pip install bayesnewton

Example

Given some inputs x and some data y, you can construct a Bayes-Newton model as follows,

kern = bayesnewton.kernels.Matern52()
lik = bayesnewton.likelihoods.Gaussian()
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)

The training loop (inference and hyperparameter learning) is then set up using objax's built in functionality:

lr_adam = 0.1
lr_newton = 1
opt_hypers = objax.optimizer.Adam(model.vars())
energy = objax.GradValues(model.energy, model.vars())

@objax.Function.with_vars(model.vars() + opt_hypers.vars())
def train_op():
    model.inference(lr=lr_newton, **inf_args)  # perform inference and update variational params
    dE, E = energy(**inf_args)  # compute energy and its gradients w.r.t. hypers
    opt_hypers(lr_adam, dE)  # update the hyperparameters
    return E

As we are using JAX, we can JIT compile the training loop:

train_op = objax.Jit(train_op)

and then run the training loop,

iters = 20
for i in range(1, iters + 1):
    loss = train_op()

Full demos are available here.

License

This software is provided under the Apache License 2.0. See the accompanying LICENSE file for details.

Citing Bayes-Newton

@article{wilkinson2021bayesnewton,
  title = {{B}ayes-{N}ewton Methods for Approximate {B}ayesian Inference with {PSD} Guarantees},
  author = {Wilkinson, William J. and S\"arkk\"a, Simo and Solin, Arno},
  journal={arXiv preprint arXiv:2111.01721},
  year={2021}
}

Implemented Models

For a full list of the all the models available see the model class list.

Variational GPs

  • Variationl GP (Opper, Archambeau: The Variational Gaussian Approximation Revisited, Neural Computation 2009; Khan, Lin: Conugate-Computation Variational Inference - Converting Inference in Non-Conjugate Models in to Inference in Conjugate Models, AISTATS 2017)
  • Sparse Variational GP (Hensman, Matthews, Ghahramani: Scalable Variational Gaussian Process Classification, AISTATS 2015; Adam, Chang, Khan, Solin: Dual Parameterization of Sparse Variational Gaussian Processes, NeurIPS 2021)
  • Markov Variational GP (Chang, Wilkinson, Khan, Solin: Fast Variational Learning in State Space Gaussian Process Models, MLSP 2020)
  • Sparse Markov Variational GP (Adam, Eleftheriadis, Durrande, Artemev, Hensman: Doubly Sparse Variational Gaussian Processes, AISTATS 2020; Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)
  • Spatio-Temporal Variational GP (Hamelijnck, Wilkinson, Loppi, Solin, Damoulas: Spatio-Temporal Variational Gaussian Processes, NeurIPS 2021)

Expectation Propagation GPs

  • Expectation Propagation GP (Minka: A Family of Algorithms for Approximate Bayesian Inference, Ph. D thesis 2000)
  • Sparse Expectation Propagation GP (energy not working) (Csato, Opper: Sparse on-line Gaussian processes, Neural Computation 2002; Bui, Yan, Turner: A Unifying Framework for Gaussian Process Pseudo Point Approximations Using Power Expectation Propagation, JMLR 2017)
  • Markov Expectation Propagation GP (Wilkinson, Chang, Riis Andersen, Solin: State Space Expectation Propagation, ICML 2020)
  • Sparse Markov Expectation Propagation GP (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)

Laplace/Newton GPs

  • Laplace GP (Rasmussen, Williams: Gaussian Processes for Machine Learning, 2006)
  • Sparse Laplace GP (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
  • Markov Laplace GP (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
  • Sparse Markov Laplace GP

Linearisation GPs

  • Posterior Linearisation GP (García-Fernández, Tronarp, Sarkka: Gaussian Process Classification Using Posterior Linearization, IEEE Signal Processing 2019; Steinberg, Bonilla: Extended and Unscented Gaussian Processes, NeurIPS 2014)
  • Sparse Posterior Linearisation GP
  • Markov Posterior Linearisation GP (García-Fernández, Svensson, Sarkka: Iterated Posterior Linearization Smoother, IEEE Automatic Control 2016; Wilkinson, Chang, Riis Andersen, Solin: State Space Expectation Propagation, ICML 2020)
  • Sparse Markov Posterior Linearisation GP (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)
  • Taylor Expansion / Analytical Linearisaiton GP (Steinberg, Bonilla: Extended and Unscented Gaussian Processes, NeurIPS 2014)
  • Markov Taylor GP / Extended Kalman Smoother (Bell: The Iterated Kalman Smoother as a Gauss-Newton method, SIAM Journal on Optimization 1994)
  • Sparse Taylor GP
  • Sparse Markov Taylor GP / Sparse Extended Kalman Smoother (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)

Gauss-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

  • Gauss-Newton
  • Variational Gauss-Newton
  • PEP Gauss-Newton
  • 2nd-order PL Gauss-Newton

Quasi-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

  • Quasi-Newton
  • Variational Quasi-Newton
  • PEP Quasi-Newton
  • PL Quasi-Newton

GPs with PSD Constraints via Riemannian Gradients

  • VI Riemann Grad (Lin, Schmidt, Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020)
  • Newton/Laplace Riemann Grad (Lin, Schmidt, Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020)
  • PEP Riemann Grad (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

Others

  • Infinite Horizon GP (Solin, Hensman, Turner: Infinite-Horizon Gaussian Processes, NeurIPS 2018)
  • Parallel Markov GP (with VI, EP, PL, ...) (Särkkä, García-Fernández: Temporal parallelization of Bayesian smoothers; Corenflos, Zhao, Särkkä: Gaussian Process Regression in Logarithmic Time; Hamelijnck, Wilkinson, Loppi, Solin, Damoulas: Spatio-Temporal Variational Gaussian Processes, NeurIPS 2021)
  • 2nd-order Posterior Linearisation GP (sparse, Markov, ...) (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
Comments
  • Addition of a Squared Exponential Kernel?

    Addition of a Squared Exponential Kernel?

    This is not an issue per se, but I was wondering if there was a specific reason that there wasn't a Squared Exponential kernel as part of the package?

    If applicable, I would be happy to submit a PR adding one.

    Just let me know.

    opened by mathDR 7
  • Add sq exponential kernel

    Add sq exponential kernel

    This PR does two things:

    1. Adds a Squared Exponential Kernel. Adopting the gpflow nomenclature for naming, i.e. SquaredExponential inheriting from StationaryKernel. As of now, only the K_r method is populated (others may be populated as needed).
    2. The marathon.py demo was changed to take the SquaredExponential kernel (with lengthscale = 40). This serves as a "test" to ensure that the kernel runs.
    opened by mathDR 2
  • question about the equation (64) in `Bayes-Newton` paper

    question about the equation (64) in `Bayes-Newton` paper

    I'm a little confused by the equation (64). It is calculated by equation (63), but where is the other term in equation (64)? such as denominator comes from the covariance of p(fn|u).

    opened by Fangwq 2
  • jitted predict

    jitted predict

    Hi, I'm starting to explore your framework. I'm familiar with jax, but not with objax. I noticed that train ops are jitted with objax.Jit, but as my goal is to have fast prediction embedded in some larger jax code, I wonder if predit() can be also jitted? Thanks in advance,

    Regards,

    opened by soldierofhell 2
  • error in heteroscedastic.py with MarkovPosteriorLinearisationQuasiNewtonGP method

    error in heteroscedastic.py with MarkovPosteriorLinearisationQuasiNewtonGP method

    As the title said, there is an error after running heteroscedastic.py with MarkovPosteriorLinearisationQuasiNewtonGP method:

    File "heteroscedastic.py", line 101, in train_op
    model.inference(lr=lr_newton, damping=damping)  # perform inference and update variational params
    File "/BayesNewton-main/bayesnewton/inference.py",  line 871, in inference
    mean, jacobian, hessian, quasi_newton_state =self.update_variational_params(batch_ind, lr, **kwargs)
    File "/BayesNewton-main/bayesnewton/inference.py",
    line 1076, in update_variational_params
    jacobian_var = transpose(solve(omega, dmu_dv)) @ residual
    ValueError: The arguments to solve must have shapes a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a=(117, 1, 1) and b=(117, 2, 2)
    

    Can you tell me where it is wrong ? Thanks in advance.

    opened by Fangwq 2
  • How to set initial `Pinf` variable in kernel?

    How to set initial `Pinf` variable in kernel?

    I note that the initial Pinf variable for Matern-5/2 kernel is as follows:

            Pinf = np.array([[self.variance,    0.0,   -kappa],
                             [0.0,    kappa, 0.0],
                             [-kappa, 0.0,   25.0*self.variance / self.lengthscale**4.0]])
    

    Why it is like that? Any references I should follow up?

    PS: by the way, the data ../data/aq_data.csv is missing.

    opened by Fangwq 2
  • How to install Newt in conda virtual environment?

    How to install Newt in conda virtual environment?

    Hi, thank you for sharing your great work.

    I am a little confused about how to install Newt in a conda VE. I really appreciate it if you could guide in this regard. Thank you

    opened by mohammad-saber 2
  • How to understand the function `cavity_distribution_tied` in file `basemodels.py` ?

    How to understand the function `cavity_distribution_tied` in file `basemodels.py` ?

    Just as the title said, how can I understand cavity_distribution_tied in file basemodels.py? Is there any reference I should follow up? And I note that this code is similar to equation (64) in the BayesNewton paper. How does it come from?

    opened by Fangwq 1
  • issue with SparseVariationalGP method

    issue with SparseVariationalGP method

    When I run the code file demos/regression.py with SparseVariationalGP, something wrong happens:

    AssertionError: Assignments to variable must be an instance of JaxArray, but received f<class 'numpy.ndarray'>.
    

    It seems that a mistake in method SparseVariationalGP . Can you help to fix the problem? Thank you very much!

    opened by Fangwq 1
  • Sparse EP energy is incorrect

    Sparse EP energy is incorrect

    The current implementation of the sparse EP energy is not giving sensible results. This is a reminder to look into the reasons why and check against implementations elsewhere. PRs very welcome for this issue.

    Note: the EP energy is correct for all other models (GP, Markov GP, SparseMarkovGP)

    bug 
    opened by wil-j-wil 1
  • Double Precision Issues

    Double Precision Issues

    Hi!

    Many thanks for open-sourcing this package.

    I've been using code from this amazing package for my research (preprint here) and have found that the default single precision/float32 is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular,

    • for the Periodic kernel, it is rather sensitive to the matrix operations in _sequential_kf() and _sequential_rts().
    • Likewise, the same when the lengthscales are too large in the Matern32 kernel.

    However, reverting to float64 by setting config.update("jax_enable_x64", True) makes everything quite slow, especially when I use objax neural network modules, due to the fact that doing so puts all arrays into double precision.

    Currently, my solution is to set the neural network weights to float32 manually, and convert input arrays into float32 before entering the network and the outputs back into float64. However, I was wondering if there could be a more elegant solution as is done in https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.

    Software and hardware details:

    objax==1.6.0
    jax==0.3.13
    jaxlib==0.3.10+cuda11.cudnn805
    
    NVIDIA-SMI 460.56       Driver Version: 460.56       CUDA Version: 11.2
    GeForce RTX 3090 GPUs
    

    Thanks in advance.

    Best, Harrison

    opened by harrisonzhu508 1
  • Cannot run demo, possible incompatibility with latest Jax

    Cannot run demo, possible incompatibility with latest Jax

    Dear all,

    I am trying to run the demo examples, but I run in the following error


    ImportError Traceback (most recent call last) Input In [22], in <cell line: 1>() ----> 1 import bayesnewton 2 import objax 3 import numpy as np

    File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/init.py:1, in ----> 1 from . import ( 2 kernels, 3 utils, 4 ops, 5 likelihoods, 6 models, 7 basemodels, 8 inference, 9 cubature 10 ) 13 def build_model(model, inf, name='GPModel'): 14 return type(name, (inf, model), {})

    File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5, in 3 import jax.numpy as np 4 from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm ----> 5 from jax.ops import index_add, index 6 from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix 7 from warnings import warn

    ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/init.py)

    I think its related to this from the Jax website:

    The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.

    opened by daniel-trejobanos 3
  • Latest versions of JAX and objax cause compile slow down

    Latest versions of JAX and objax cause compile slow down

    It is recommended to use the following versions of jax and objax:

    jax==0.2.9
    jaxlib==0.1.60
    objax==1.3.1
    

    This is because of this objax issue which causes the model to JIT compile "twice", i.e. on the first two iterations rather than just the first. This causes a bit of a slow down for large models, but is not an problem otherwise.

    bug 
    opened by wil-j-wil 0
Releases(v1.2.0)
Owner
AaltoML
Machine learning group at Aalto University lead by Prof. Solin
AaltoML
Self-training with Weak Supervision (NAACL 2021)

This repo holds the code for our weak supervision framework, ASTRA, described in our NAACL 2021 paper: "Self-Training with Weak Supervision"

Microsoft 148 Nov 20, 2022
Portfolio analytics for quants, written in Python

QuantStats: Portfolio analytics for quants QuantStats Python library that performs portfolio profiling, allowing quants and portfolio managers to unde

Ran Aroussi 2.7k Jan 08, 2023
“袋鼯麻麻——智能购物平台”能够精准地定位识别每一个商品

“袋鼯麻麻——智能购物平台”能够精准地定位识别每一个商品,并且能够返回完整地购物清单及顾客应付的实际商品总价格,极大地降低零售行业实际运营过程中巨大的人力成本,提升零售行业无人化、自动化、智能化水平。

thomas-yanxin 192 Jan 05, 2023
Anti-Adversarially Manipulated Attributions for Weakly and Semi-Supervised Semantic Segmentation (CVPR 2021)

Anti-Adversarially Manipulated Attributions for Weakly and Semi-Supervised Semantic Segmentation Input Image Initial CAM Successive Maps with adversar

Jungbeom Lee 110 Dec 07, 2022
Official Implementation of "Tracking Grow-Finish Pigs Across Large Pens Using Multiple Cameras"

Multi Camera Pig Tracking Official Implementation of Tracking Grow-Finish Pigs Across Large Pens Using Multiple Cameras CVPR2021 CV4Animals Workshop P

44 Jan 06, 2023
Pipeline code for Sequential-GAM(Genome Architecture Mapping).

Sequential-GAM Pipeline code for Sequential-GAM(Genome Architecture Mapping). mapping whole_preprocess.sh include the whole processing of mapping. usa

3 Nov 03, 2022
MVP Benchmark for Multi-View Partial Point Cloud Completion and Registration

MVP Benchmark: Multi-View Partial Point Clouds for Completion and Registration [NEWS] 2021-07-12 [NEW 🎉 ] The submission on Codalab starts! 2021-07-1

PL 93 Dec 21, 2022
PyTorch code of "SLAPS: Self-Supervision Improves Structure Learning for Graph Neural Networks"

SLAPS-GNN This repo contains the implementation of the model proposed in SLAPS: Self-Supervision Improves Structure Learning for Graph Neural Networks

60 Dec 22, 2022
A list of multi-task learning papers and projects.

This page contains a list of papers on multi-task learning for computer vision. Please create a pull request if you wish to add anything. If you are interested, consider reading our recent survey pap

svandenh 297 Dec 17, 2022
MPI Interest Group on Algorithms on 1st semester 2021

MPI Algorithms Interest Group Introduction Lecturer: Steve Yan Location: TBA Time Schedule: TBA Semester: 1 Useful URLs Typora: https://typora.io Goog

Ex10si0n 13 Sep 08, 2022
Nerf pl - NeRF (Neural Radiance Fields) and NeRF in the Wild using pytorch-lightning

nerf_pl Update: an improved NSFF implementation to handle dynamic scene is open! Update: NeRF-W (NeRF in the Wild) implementation is added to nerfw br

AI葵 1.8k Dec 30, 2022
A TensorFlow implementation of DeepMind's WaveNet paper

A TensorFlow implementation of DeepMind's WaveNet paper This is a TensorFlow implementation of the WaveNet generative neural network architecture for

Igor Babuschkin 5.3k Dec 28, 2022
[TIP2020] Adaptive Graph Representation Learning for Video Person Re-identification

Introduction This is the PyTorch implementation for Adaptive Graph Representation Learning for Video Person Re-identification. Get started git clone h

WuYiming 41 Dec 12, 2022
This repository includes the official project for the paper: TransMix: Attend to Mix for Vision Transformers.

TransMix: Attend to Mix for Vision Transformers This repository includes the official project for the paper: TransMix: Attend to Mix for Vision Transf

Jie-Neng Chen 130 Jan 01, 2023
Контрольная работа по математическим методам машинного обучения

ML-MathMethods-Test Контрольная работа по математическим методам машинного обучения. Вычисление основных статистик, диаграмм и графиков, проверка разл

Stas Ivanovskii 1 Jan 06, 2022
Adaptive Dropblock Enhanced GenerativeAdversarial Networks for Hyperspectral Image Classification

This repo holds the codes of our paper: Adaptive Dropblock Enhanced GenerativeAdversarial Networks for Hyperspectral Image Classification, which is ac

Feng Gao 17 Dec 28, 2022
Official implementation of the paper Visual Parser: Representing Part-whole Hierarchies with Transformers

Visual Parser (ViP) This is the official implementation of the paper Visual Parser: Representing Part-whole Hierarchies with Transformers. Key Feature

Shuyang Sun 117 Dec 11, 2022
Unofficial pytorch-lightning implement of Mip-NeRF

mipnerf_pl Unofficial pytorch-lightning implement of Mip-NeRF, Here are some results generated by this repository (pre-trained models are provided bel

Jianxin Huang 159 Dec 23, 2022
Dynamics-aware Adversarial Attack of 3D Sparse Convolution Network

Leaded Gradient Method (LGM) This repository contains the PyTorch implementation for paper Dynamics-aware Adversarial Attack of 3D Sparse Convolution

An Tao 2 Oct 18, 2022
PyTorch code for our paper "Attention in Attention Network for Image Super-Resolution"

Under construction... Attention in Attention Network for Image Super-Resolution (A2N) This repository is an PyTorch implementation of the paper "Atten

Haoyu Chen 71 Dec 30, 2022