JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library

Overview

JAX bindings to FINUFFT

This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to directly expose the FINUFFT library to JAX's XLA backend, as well as implementing differentiation rules for the transforms.

Included features

This library is currently CPU-only, but GPU support is in the works using the cuFINUFFT library.

Type 1 and type 2 transforms are supported in 1, 2, and 3 dimensions. All of these functions support forward-mode, reverse-mode, and higher-order differentiation, as well as batching using vmap.
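To see why derivatives with respect to the non-uniform points are cheap, note that differentiating a type 1 sum with respect to the points gives another type 1 sum, scaled by the mode index. The sketch below checks this with a direct-summation stand-in. It is not the package's code: nudft1_direct is a hypothetical helper, and the mode ordering (integers from -N/2 upward) and positive exponent sign are assumptions.

```python
import numpy as np

def nudft1_direct(N, c, x):
    # Hypothetical reference: f[k] = sum_j c[j] * exp(1j * k * x[j])
    k = np.arange(-(N // 2), (N + 1) // 2)
    return np.exp(1j * np.outer(k, x)) @ c

rng = np.random.default_rng(0)
N, M = 8, 20
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
dx = rng.standard_normal(M)  # tangent direction for the points

# Analytic directional derivative: another type 1 sum with strengths c * dx,
# multiplied by 1j * k for each output mode.
k = np.arange(-(N // 2), (N + 1) // 2)
jvp = 1j * k * nudft1_direct(N, c * dx, x)

# Central finite difference along the same direction, for comparison.
h = 1e-6
fd = (nudft1_direct(N, c, x + h * dx) - nudft1_direct(N, c, x - h * dx)) / (2 * h)
```

The two arrays jvp and fd should agree to within finite-difference error, which is the property a differentiation rule for the transform can exploit.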

Installation

For now, only a source build is supported.

For building, you should only need a recent version of Python (>3.6) and FFTW. At runtime, you'll need numpy, scipy, and jax. To set up such an environment, you can use conda (but you're welcome to use whatever workflow works for you!):

conda create -n jax-finufft -c conda-forge python=3.9 numpy scipy fftw
conda activate jax-finufft
python -m pip install "jax[cpu]"

Then you can install from source using (don't forget the --recursive flag because FINUFFT is included as a submodule):

git clone --recursive https://github.com/dfm/jax-finufft
cd jax-finufft
python -m pip install .

Usage

This library provides two high-level functions (and these should be all that you generally need to interact with): nufft1 and nufft2 (for the two "types" of transforms). If you're already familiar with the Python interface to FINUFFT, please note that the function signatures here are different!

For example, here's how you can do a 1-dimensional type 1 transform:

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)

Note that eps and iflag are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed from the finufft Python package.
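For reference, the quantity that a 1D type 1 transform approximates can be written as a direct sum. The sketch below is not part of jax-finufft: nudft1_direct is a hypothetical helper, and it assumes FINUFFT's documented convention (integer modes ordered from -N/2 upward, with the sign of the exponent set by iflag). It costs O(N * M) operations, so it is only useful for checking small problems.

```python
import numpy as np

def nudft1_direct(N, c, x, iflag=1):
    """Direct-summation reference for a 1D type 1 transform (hypothetical helper)."""
    # Assumed mode ordering: integer frequencies from -N//2 upward.
    k = np.arange(-(N // 2), (N + 1) // 2)
    sign = 1 if iflag >= 0 else -1
    # f[k] = sum_j c[j] * exp(sign * 1j * k * x[j])
    return np.exp(sign * 1j * np.outer(k, x)) @ c

rng = np.random.default_rng(0)
M, N = 50, 16
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f = nudft1_direct(N, c, x)
```

If the package follows the same mode ordering, nufft1(N, c, x, eps=1e-6, iflag=1) on the same small inputs should agree with f to roughly the requested tolerance.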

The syntax for a 2- or 3-dimensional transform is:

f = nufft1((Nx, Ny), c, x, y)  # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D

The syntax for a type 2 transform is (also allowing optional iflag and eps parameters):

c = nufft2(f, x)  # 1D
c = nufft2(f, x, y)  # 2D
c = nufft2(f, x, y, z)  # 3D
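The two transform types are closely related: with opposite exponent signs, a type 2 transform is the adjoint of a type 1 transform. The following sketch illustrates this with direct sums. The helpers nudft1_direct and nudft2_direct are hypothetical stand-ins rather than package functions, and the mode ordering is assumed to run from -N/2 upward.

```python
import numpy as np

def modes(N):
    # Assumed ordering of the integer frequencies.
    return np.arange(-(N // 2), (N + 1) // 2)

def nudft1_direct(N, c, x, iflag=1):
    sign = 1 if iflag >= 0 else -1
    # f[k] = sum_j c[j] * exp(sign * 1j * k * x[j])
    return np.exp(sign * 1j * np.outer(modes(N), x)) @ c

def nudft2_direct(f, x, iflag=-1):
    sign = 1 if iflag >= 0 else -1
    # c[j] = sum_k f[k] * exp(sign * 1j * k * x[j])
    return np.exp(sign * 1j * np.outer(x, modes(f.shape[0]))) @ f

rng = np.random.default_rng(0)
M, N = 30, 10
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f = rng.standard_normal(N) + 1j * rng.standard_normal(N)

# Adjoint identity: <f, T1 c> == <T2 f, c> when the signs are opposite.
lhs = np.vdot(f, nudft1_direct(N, c, x, iflag=1))
rhs = np.vdot(nudft2_direct(f, x, iflag=-1), c)
```

Here lhs and rhs should match to floating-point precision, which is a convenient sanity check when choosing iflag for a forward/adjoint pair.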

Similar libraries

  • finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
  • mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.

License & attribution

This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:

Copyright 2021 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed on the FINUFFT docs.

Comments
  • batching issue

    Hi,

    I get the following error when I try to batch nufft2 in Jax.

    process_primitive(self, primitive, tracers, params)
        161     frame = self.get_frame(vals_in, dims_in)
        162     batched_primitive = self.get_primitive_batcher(primitive, frame)
    --> 163     val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
        164     if primitive.multiple_results:
        165       return map(partial(BatchTracer, self), val_out, dim_out)

    TypeError: batch() got an unexpected keyword argument 'output_shape'

    It seems like this is caused by:

    nufft2(source, iflag, eps, *points)
         57
         58     return jnp.reshape(
    ---> 59         nufft2_p.bind(source, *points, output_shape=None, iflag=iflag, eps=eps),
         60         expected_output_shape,
         61     )

    Is there something I am doing wrong?

    Thanks for your help!

    opened by samaktbo 13
  • problem batching functions that have multiple arguments

    Hi,

    I am posting this here to give a clearer description of the problem that I am having.

    When I run the following snippet, I would like A to be a 4 by 100 array whose i-th row is the output of linear_func(q, X[i, :]).

    import numpy as np
    import jax.numpy as jnp
    from jax import vmap
    from jax_finufft import nufft2
    
    rng = np.random.default_rng(seed=314)
    
    d=10 
    L_tilde = 10
    L = 100
    
    qr = rng.standard_normal((d, L_tilde+1))
    qi = rng.standard_normal((d, L_tilde+1))
    q = jnp.array(qr + 1j * qi)
    X = jnp.array(rng.uniform(low=0.0, high=1.0, size=(4, L)))
    
    def linear_func(q, x):
      v = jnp.ones(shape=(1, L_tilde))
    
      return jnp.matmul(v, nufft2(q, x, eps=1e-6, iflag=-1))
    
    batched = vmap(linear_func, in_axes=(None, 0), out_axes=0)
    
    A = batched(q, X)
    

    However, when I run the snippet I get the error posted below. You had said last time that there could be a workaround for unbatched arguments, but I could not figure it out.

    Here is the error:

    UnfilteredStackTrace                      Traceback (most recent call last)
    <ipython-input-25-c23bc4fa7eb4> in <module>()
    ----> 1 test = batched(q, X)
    
    30 frames
    UnfilteredStackTrace: TypeError: '<' not supported between instances of 'NoneType' and 'int'
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/jax_finufft/ops.py in <genexpr>(.0)
        281     else:
        282         mx = args[0].ndim - ndim - 1
    --> 283     assert all(a < mx for a in axes)
        284     assert all(a == axes[0] for a in axes[1:])
        285     return prim.bind(*args, **kwargs), axes[0]
    
    TypeError: '<' not supported between instances of 'NoneType' and 'int'
    

    I hope this gives a clearer picture than what I had last time. Thanks so much for your help!
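One generic way to sidestep a batching rule that requires every argument to be mapped is to broadcast the shared argument by hand, so that in_axes contains no None. The sketch below shows the pattern with a pure-JAX stand-in: toy_transform is hypothetical (a direct type 2 sum, not nufft2), and whether a given jax-finufft version needs this workaround is an assumption.

```python
import jax
import jax.numpy as jnp

def toy_transform(f, x):
    # Hypothetical stand-in for nufft2(f, x): direct type 2 summation.
    n = f.shape[-1]
    k = jnp.arange(-(n // 2), (n + 1) // 2)
    return jnp.exp(-1j * jnp.outer(x, k)) @ f

f = jnp.arange(5) + 1j * jnp.ones(5)          # shared across the batch
X = jnp.linspace(0.0, 1.0, 12).reshape(4, 3)  # one row of points per batch entry

# Instead of in_axes=(None, 0), give the shared argument an explicit batch axis.
f_tiled = jnp.broadcast_to(f, (X.shape[0],) + f.shape)
A = jax.vmap(toy_transform, in_axes=(0, 0))(f_tiled, X)

# For a pure-JAX function the two spellings agree.
A_ref = jax.vmap(toy_transform, in_axes=(None, 0))(f, X)
```

The cost of this approach is the extra memory for the tiled copy of the shared argument.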

    opened by samaktbo 5
  • Installation issue

    I am having trouble installing jax-finufft even with the instructions on the home page. The installation fails with:

    ERROR: Failed building wheel for jax-finufft
    Failed to build jax-finufft
    ERROR: Could not build wheels for jax-finufft, which is required to install pyproject.toml-based projects

    opened by samaktbo 5
  • Error when differentiating nufft1 with respect to points only

    Hi Dan,

    First, thank you for releasing this package, I was very glad to find it!

    I noted the error below when attempting to differentiate nufft1 with respect to points. I am very new to JAX so I could be mistaken, but I don't believe this is the intended behavior:

    from jax_finufft import nufft1, nufft2
    import numpy as np
    import jax.numpy as jnp
    from jax import grad
    M = 100000
    N = 200000
    
    x = 2 * np.pi * np.random.uniform(size=M)
    c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
    
    def norm_nufft1(c,x):
        f = nufft1(N, c, x, eps=1e-6, iflag=1)
        return jnp.linalg.norm(f)
    
    def norm_nufft2(c,x):
        f = nufft2( c, x, eps=1e-6, iflag=1)
        return jnp.linalg.norm(f)
    
    grad(norm_nufft2,argnums =(1))(c,x) # Works fine
    grad(norm_nufft1,argnums =(0,))(c,x) # Works fine
    grad(norm_nufft1,argnums =(0,1))(c,x) # Works fine
    grad(norm_nufft1,argnums =(1))(c,x) # Throws error
    

    The error is below:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    /var/folders/dg/zzj57d7d1gs1k9l8sdfh043m0000gn/T/ipykernel_13430/3134504975.py in <module>
         22 grad(norm_nufft1,argnums =(0,))(c,x) # Works fine
         23 grad(norm_nufft1,argnums =(0,1))(c,x) # Works fine
    ---> 24 grad(norm_nufft1,argnums =(1))(c,x) # Throws error
         25 
         26 
    
        [... skipping hidden 10 frame]
    
    /var/folders/dg/zzj57d7d1gs1k9l8sdfh043m0000gn/T/ipykernel_13430/3134504975.py in norm_nufft1(c, x)
         11 
         12 def norm_nufft1(c,x):
    ---> 13     f = nufft1(N, c, x, eps=1e-6, iflag=1)
         14     return jnp.linalg.norm(f)
         15 
    
        [... skipping hidden 21 frame]
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax_finufft/ops.py in nufft1(output_shape, source, iflag, eps, *points)
         39 
         40     return jnp.reshape(
    ---> 41         nufft1_p.bind(source, *points, output_shape=output_shape, iflag=iflag, eps=eps),
         42         expected_output_shape,
         43     )
    
        [... skipping hidden 3 frame]
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax_finufft/ops.py in jvp(type_, prim, args, tangents, output_shape, iflag, eps)
        248 
        249         axis = -2 if type_ == 2 else -ndim - 1
    --> 250         output_tangent *= jnp.concatenate(jnp.broadcast_arrays(*scales), axis=axis)
        251 
        252         expand_shape = (
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in concatenate(arrays, axis)
       3405   if hasattr(arrays[0], "concatenate"):
       3406     return arrays[0].concatenate(arrays[1:], axis)
    -> 3407   axis = _canonicalize_axis(axis, ndim(arrays[0]))
       3408   arrays = _promote_dtypes(*arrays)
       3409   # lax.concatenate can be slow to compile for wide concatenations, so form a
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax/_src/util.py in canonicalize_axis(axis, num_dims)
        277   axis = operator.index(axis)
        278   if not -num_dims <= axis < num_dims:
    --> 279     raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
        280   if axis < 0:
        281     axis = axis + num_dims
    
    ValueError: axis -2 is out of bounds for array of dimension 1
    
    bug 
    opened by ma-gilles 4
  • Speed up `vmap`s where none of the points are batched

    If none of the points are batched in a vmap, we should be able to get a faster computation by stacking the transforms into a single transform and then reshaping. It might be worth making the stacked axes into an explicit parameter rather than just trying to infer it, but that would take some work on the interface.
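The idea of stacking can be illustrated with direct sums: when every transform in a batch shares the same points, the batch axes can be flattened, all transforms evaluated in one pass, and the result reshaped back. This is a minimal NumPy sketch under assumed conventions (modes from -N/2 upward, positive exponent sign), not the package's implementation.

```python
import numpy as np

rng = np.random.default_rng(1)
batch_shape, M, N = (2, 3), 40, 8
x = 2 * np.pi * rng.uniform(size=M)  # points shared by every transform
C = rng.standard_normal(batch_shape + (M,)) + 1j * rng.standard_normal(batch_shape + (M,))

k = np.arange(-(N // 2), (N + 1) // 2)
kernel = np.exp(1j * np.outer(k, x))  # shape (N, M), built once

# Looped version: one transform per batch entry.
looped = np.stack([kernel @ c for c in C.reshape(-1, M)]).reshape(batch_shape + (N,))

# Stacked version: flatten the batch axes, transform once, reshape back.
stacked = (C.reshape(-1, M) @ kernel.T).reshape(batch_shape + (N,))
```

FINUFFT's plan interface exposes a similar notion through its n_trans parameter, which is presumably what a stacked batching rule would target.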

    opened by dfm 0
  • WIP: Adding support for non-batched dimensions in `vmap`

    In most cases, we'll just need to broadcast all the inputs out to the right shapes (this shouldn't be too hard), but when none of the points get mapped, we can get a bit of a speed up by stacking the transforms. This starts to implement that logic, but it's not quite ready yet.

    opened by dfm 0
  • Find a publication venue

    @lgarrison and I have been chatting about the possibility of writing a paper describing what we're doing here. Something like "Differentiable programming with NUFFTs" or "NUFFTs for machine learning applications" or something. This issue is here to remind me to look into possible venues for this.

    opened by dfm 5
  • Starting to add GPU support using cuFINUFFT

    So far I just have the CMake definitions to compile cuFINUFFT when nvcc is found, but I haven't started writing the boilerplate needed to loop it into XLA. Coming soon!

    Keeping @lgarrison in the loop.

    opened by dfm 10
  • Support Type 3?

    There aren't currently any plans to support the Type 3 transform for a few reasons:

    • I'm not totally sure of the use cases,
    • The logic will probably be somewhat more complicated than the existing implementations, and
    • cuFINUFFT doesn't seem to support Type 3.

    Do we want to work around these issues?

    opened by dfm 0
  • Add support for handling errors

    Currently, if any of the finufft methods fail with a non-zero error we just ignore it and keep on trucking. How should we handle this? It looks like the JAX convention is currently to set everything to NaN when an op fails, since propagating errors up from XLA can be a bit of a pain.
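The NaN convention mentioned above could look roughly like the following. This is only a sketch: nan_on_failure and the integer status argument are hypothetical, and how a status code would actually be returned from the XLA custom call is left open.

```python
import jax.numpy as jnp

def nan_on_failure(result, status):
    # `status` is a hypothetical integer code where 0 means success.
    # On failure, every entry of the result is replaced by NaN.
    return jnp.where(status == 0, result, jnp.nan)

f = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
ok = nan_on_failure(f, 0)      # unchanged
failed = nan_on_failure(f, 2)  # all NaN
```

Downstream code can then detect a failed transform with jnp.isnan instead of catching an exception.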

    opened by dfm 0
  • Add GPU support

    Via cufinufft. We'll need to figure out how to get XLA's handling of CUDA streams to play nice with cufinufft (this is way above my pay grade). Some references:

    1. A simple example of how to write an XLA compatible CUDA kernel: https://github.com/dfm/extending-jax/blob/main/lib/kernels.cc.cu
    2. The source code for how JAX wraps cuBLAS: https://github.com/google/jax/blob/main/jaxlib/cublas_kernels.cc
    3. There is a tensorflow implementation that might have some useful context: https://github.com/mrphys/tensorflow-nufft/blob/master/tensorflow_nufft/cc/kernels/nufft_kernels.cu.cc

    Perhaps @lgarrison is interested :D

    opened by dfm 0
Releases (v0.0.3)
  • v0.0.3 (Dec 10, 2021)

    What's Changed

    • Fix segfault when batching multiple transforms by @dfm in https://github.com/dfm/jax-finufft/pull/11
    • Generalize the behavior of vmap by @dfm in https://github.com/dfm/jax-finufft/pull/12

    Full Changelog: https://github.com/dfm/jax-finufft/compare/v0.0.2...v0.0.3

  • v0.0.2 (Nov 12, 2021)

    • Faster differentiation using stacked transforms
    • Better error checking for vmap

    Full Changelog: https://github.com/dfm/jax-finufft/compare/v0.0.1...v0.0.2

  • v0.0.1 (Nov 8, 2021)

Owner
Dan Foreman-Mackey