Equinox
Callable PyTrees and filtered JIT/grad transformations
=> neural networks in JAX
Equinox brings more power to your model building in JAX.
Represent parameterised functions as data, and use filtered transformations for powerful fine-grained control of the model-building process.
Equinox is half tech-demo, half neural network library.
Equinox in brief
Building neural networks
Build models using a PyTorch-like class based API without sacrificing JAX-like functional programming.
In particular, without extra complexity like class-to-functional transformations, custom notions of parameter groups, or specially wrapped library.jits and library.grads, like many libraries have.
Equinox is a tiny library -- no behind-the-scenes magic, guaranteed. The elegance of Equinox is its selling point in a world that already has Haiku, Flax etc.
Technical contributions
Equinox represents parameterised functions as data. That is, you can represent your whole model (parameters, buffers, forward pass, etc.) as a PyTree. Parameterised functions can be passed in and out of higher-order functions -- like passing models to jax.vmap, vmap'd functions to loss functions, or loss functions to JIT and grad.
Equinox additionally offers thin wrappers around jax.jit/jax.grad that understand the PyTree structure of their inputs: you can JIT/differentiate a single leaf, not just a whole argument. (We don't offer this for jax.vmap because interestingly jax.vmap offers this already.)
There are some similarities to existing libraries (like the structs of flax.linen or the functors of Flux.jl), but to the best of my knowledge Equinox offers something genuinely new to the JAX framework.
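As a rough illustration of the remark that jax.vmap already offers per-leaf control: its in_axes argument accepts a PyTree that mirrors the argument, so individual leaves can be batched or shared. The sketch below uses only plain JAX; the `affine` function and the `params` dictionary are made up for this example.

```python
import jax
import jax.numpy as jnp


def affine(params, x):
    # `params` is a PyTree (here a dict) with two leaves.
    return params["weight"] @ x + params["bias"]


params = {"weight": jnp.ones((3, 2)), "bias": jnp.zeros((5, 3))}
xs = jnp.ones((5, 2))

# `in_axes` mirrors the argument structure: `weight` is shared (None),
# while `bias` and `xs` are batched along axis 0.
batched = jax.vmap(affine, in_axes=({"weight": None, "bias": 0}, 0))
out = batched(params, xs)
```

Here a single argument mixes a shared leaf and a batched leaf, which is the same kind of per-leaf control that the filtered transformations add to JIT and grad.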
Installation
pip install git+https://github.com/patrick-kidger/equinox.git
Requires Python 3.7+ and JAX 0.2.18+.
Quick example
import equinox as eqx
import functools as ft, jax, jax.numpy as jnp, jax.random as jrandom
# Define our model. `Module` subclasses are both functions and data, so we can pass them into higher
# order functions like vmap/jit/grad, or our loss function later.
# There's no magic in `Module`. Pretty much all it does is just register your class as a PyTree node.
class LinearOrIdentity(eqx.Module):
    weight: jnp.ndarray
    flag: bool

    def __init__(self, in_features, out_features, flag, key):
        self.weight = jrandom.normal(key, (out_features, in_features))
        self.flag = flag

    def __call__(self, x):
        if self.flag:
            return x
        return self.weight @ x
# We use the fact that our model is data, by passing it in as an argument to the loss.
# There's no magic here: `model` is a PyTree like any other.
#
# We use filtered transformations to unpack its data and select just the leaves we want to
# JIT+differentiate. (In this case, all floating-point JAX arrays -- `weight` but not `flag`.)
# There's no magic here: filtered transformations act on any kind of PyTree.
#
# Equinox is JAX-friendly. If you want to differentiate everything, just use `jax.jit` and `jax.grad`.
@ft.partial(eqx.jitf, filter_fn=eqx.is_inexact_array)
@ft.partial(eqx.gradf, filter_fn=eqx.is_inexact_array)
def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)
modelkey, xkey, ykey = jrandom.split(jrandom.PRNGKey(0), 3)
model = LinearOrIdentity(2, 3, flag=False, key=modelkey)
x, y = jrandom.normal(xkey, (100, 2)), jrandom.normal(ykey, (100, 3))
grads = loss(model, x, y)
This quick example exposes you to the two main concepts in Equinox: callable PyTrees and filtered transformations. Together, they're very powerful.
Callable PyTrees
A callable PyTree is just a PyTree with some methods attached. (In this case, the __call__ method of a Module subclass.) All that subclassing Module really does is automatically register your class with JAX as a custom PyTree node; there's no magic here.
The PyTree structure holds the data (parameters, buffers, submodules, boolean flags, even arbitrary Python objects). The methods on the class define operations parameterised by that data -- in this case and in particular, the forward pass through a model.
This gives a way to represent parameterised functions as data, and as such they are suitable for passing in and out of JAX functions. This is what we do when passing the model instance to the loss function.
Footnote: callable PyTrees actually aren't anything special -- the built-in Python methods on lists and dictionaries are another example of callable PyTrees.
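To make the "no magic" claim concrete, here is a minimal sketch of a callable PyTree written by hand in plain JAX. It is not how Module is implemented; the `Scale` class is hypothetical and only shows that registering a class as a PyTree node is enough to pass instances through jax.grad.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Scale:
    """A hypothetical callable PyTree: one array leaf plus a method."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return self.factor * x

    def tree_flatten(self):
        # Children are the leaves JAX sees; there is no static auxiliary data here.
        return (self.factor,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x):
    # The model is passed in as ordinary data and then called like a function.
    return jnp.sum(model(x) ** 2)


scale = Scale(jnp.array(3.0))
grads = jax.grad(loss)(scale, jnp.array([1.0, 2.0]))
```

The gradient comes back as another `Scale` instance, with the gradient of the loss stored in `factor`.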
Filtered transformations
The one issue with putting everything about a model into a single PyTree is that the PyTree might not contain just trainable parameters. The above example includes a boolean flag, for example. We certainly can't differentiate this, and we may or may not want to trace it under JIT rather than treat it as static.
In general we might have arbitrary Python objects, or perhaps JAX arrays that are buffers rather than trainable parameters.
Enter filtered transformations. These are equinox.jitf and equinox.gradf, which are very thin wrappers around jax.jit and jax.grad. Instead of specifying argnums to JIT/differentiate, we instead pass a filter that determines which PyTree leaves -- not just whole arguments -- to JIT/differentiate.
These aren't "a way to make JIT/grad work with model states" like many libraries have. They are general operations on PyTrees, and nothing about Module is special-cased.
- For one thing, we don't need to special-case anything: Module is just a PyTree like any other.
- For another, if you don't want to filter out anything at all, then don't: use jax.jit and jax.grad directly and they'll work just fine.
This gives a powerful fine-grained way to control JIT and autodifferentiation.
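One way to picture a filtered transformation is: flatten the PyTree, split the leaves with a filter, and apply the plain JAX transformation to the selected leaves only. The sketch below does this for jax.grad. It is an illustration under that assumption, not Equinox's actual implementation; `filtered_grad_sketch` and `is_float_array` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten


def is_float_array(leaf):
    # Hypothetical filter: select floating-point (inexact) JAX arrays only.
    return isinstance(leaf, jax.Array) and jnp.issubdtype(leaf.dtype, jnp.inexact)


def filtered_grad_sketch(fun, filter_fn):
    """Differentiate `fun` w.r.t. the leaves of its first argument that pass `filter_fn`."""

    def wrapper(tree, *args):
        leaves, treedef = tree_flatten(tree)
        mask = [filter_fn(leaf) for leaf in leaves]
        selected = [leaf for leaf, keep in zip(leaves, mask) if keep]

        def inner(selected_leaves):
            # Rebuild the full tree from the differentiated and the untouched leaves.
            it = iter(selected_leaves)
            merged = [next(it) if keep else leaf for leaf, keep in zip(leaves, mask)]
            return fun(tree_unflatten(treedef, merged), *args)

        grads = iter(jax.grad(inner)(selected))
        # Leaves that were filtered out get a gradient of `None`.
        return tree_unflatten(treedef, [next(grads) if keep else None for keep in mask])

    return wrapper


def loss(model, x):
    # `flag` stays an ordinary Python bool, so it can drive Python control flow.
    pred = x if model["flag"] else model["weight"] * x
    return jnp.sum(pred)


model = {"weight": jnp.array([1.0, 2.0]), "flag": False}
x = jnp.array([3.0, 4.0])
grads = filtered_grad_sketch(loss, is_float_array)(model, x)
```

The result has the same structure as `model`: an array for `weight` and `None` for `flag`.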
Integrates smoothly with JAX
There's nothing special about Equinox modules. They're just PyTrees.
There's nothing special about filtered transformations. They just operate on PyTrees.
Equinox is all just regular JAX -- PyTrees and transformations! Together, these two pieces allow us to specify complex models in JAX-friendly ways.
Examples
- train_mlp.py gives a short example that introduces equinox.jitf and equinox.gradf. These will be used to select the parameters of an MLP and train them.
- frozen_layer.py demonstrates how this approach really shines: some of the parameters will be trained, some of them will be frozen, but all of them will be efficiently JIT-traced.
- build_model.py demonstrates how to build parameterised-functions-as-data using equinox.Module. In particular we'll construct an MLP from scratch, and then pass it into higher-order functions like JIT and grad in order to train it. This allows us to produce models using a familiar class-based syntax, that are also functional and integrate directly with JAX's JIT/autograd.
- train_rnn.py trains an RNN on a toy clockwise/anticlockwise spiral classification problem. This demonstrates the use of jax.lax.scan with Equinox. (It just works, no tricks required.)
API
Full API list
# Filtered transformations
equinox.jitf
equinox.gradf
equinox.value_and_grad_f

# Filters
equinox.is_inexact_array
equinox.is_array_like

# Module
equinox.Module

# Utilities
equinox.apply_updates
equinox.tree_at
equinox.tree_equal

# Neural networks
equinox.nn.Linear
equinox.nn.Identity
equinox.nn.Dropout
equinox.nn.GRUCell
equinox.nn.LSTMCell
equinox.nn.Sequential
equinox.nn.MLP
Filtered transformations
equinox.jitf(fun, *, filter_fn=None, filter_tree=None, **kwargs)
Wraps jax.jit.
- fun is a pure function to JIT compile.
- filter_fn is a callable Any -> bool. It will be called on every leaf of every PyTree that is passed as input to fun. If it returns True, the leaf will be traced. If it returns False, the leaf will be treated as static. Mutually exclusive with filter_tree.
- filter_tree is a tree, or tuple of trees, of the same length as the number of inputs. (Or if static_argnums is passed, the number of inputs not already marked static via static_argnums.) It must have the exact same tree structure as the inputs. Every leaf must be either True or False. Each leaf of filter_tree is matched up against the corresponding input: if it is True the leaf will be traced; if it is False the leaf will be treated as static. Mutually exclusive with filter_fn.
- **kwargs are the usual other arguments to jax.jit, like static_argnums. In particular, a leaf will be marked static if either (a) it is filtered as being so, or (b) it is part of a PyTree that is marked through static_argnums.
Precisely one of filter_fn or filter_tree must be passed.
equinox.is_array_like is usually a good choice of filter_fn: it will trace everything that can possibly be traced, with everything else static.
See also equinox.tree_at for an easy way to create the filter_tree argument.
equinox.gradf(fun, *, filter_fn=None, filter_tree=None, **kwargs)
Wraps jax.grad.
- fun is a pure function to differentiate.
- filter_fn is a callable Any -> bool. It will be called on every leaf of every PyTree that is marked as potentially requiring gradient via argnums. If it returns True, the leaf will be differentiated. If it returns False, the leaf will not be differentiated. Mutually exclusive with filter_tree.
- filter_tree is a tree, or tuple of trees, of the same length as the number of inputs marked as potentially requiring gradient via argnums. It must have the exact same tree structure as the inputs. Every leaf must be either True or False. Each leaf of filter_tree is matched up against the corresponding input: if it is True the leaf will be differentiated; if it is False the leaf will not be differentiated. Mutually exclusive with filter_fn.
- **kwargs are the usual other arguments to jax.grad, like argnums. In particular, a leaf will only be differentiated if (a) it is filtered as being so, and (b) it is part of a PyTree that is marked through argnums.
Precisely one of filter_fn or filter_tree must be passed.
equinox.is_inexact_array is usually a good choice of filter_fn: it will differentiate all floating-point arrays.
See also equinox.tree_at for an easy way to create the filter_tree argument.
Note that because the returned gradients must have the same structure as the inputs, all nondifferentiable components of the input PyTrees will have gradient None. As a result, a simple jax.tree_map(lambda m, g: m - lr * g, model, grad) will fail. Equinox therefore provides equinox.apply_updates as a simple convenience: it will only apply the update if the gradient is not None. See below.
equinox.value_and_grad_f(fun, *, filter_fn=None, filter_tree=None, **kwargs)
Wraps jax.value_and_grad. Arguments are as equinox.gradf.
Filters
Any function Any -> bool can be used as a filter. We provide some convenient common choices.
equinox.is_inexact_array(element)
Returns True if element is a floating point JAX array (but not a NumPy array).
equinox.is_array_like(element)
Returns True if element can be interpreted as a JAX array. (That is, if jax.numpy.array accepts it without throwing an exception.)
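Since a filter is just a function Any -> bool, it is easy to write your own. The sketch below shows two filters that behave like the descriptions above, written in plain JAX. They are assumptions about the behaviour, not the library's source; the `_sketch` names and the example `tree` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_inexact_array_sketch(element):
    # True only for JAX arrays with a floating-point (inexact) dtype.
    return isinstance(element, jax.Array) and jnp.issubdtype(element.dtype, jnp.inexact)


def is_array_like_sketch(element):
    # True if jax.numpy.array accepts the element without raising.
    try:
        jnp.array(element)
    except Exception:
        return False
    return True


tree = {"weight": jnp.ones(2), "steps": jnp.arange(3), "flag": True, "name": "mlp"}

# Applying a filter to every leaf gives a tree of booleans with the same structure.
mask = jax.tree_util.tree_map(is_inexact_array_sketch, tree)
```

Only `weight` is selected by the first filter: `steps` is an integer array, and `flag` and `name` are not JAX arrays at all.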
Module
equinox.Module
Base class; create your model by inheriting from this.
Specify all its attributes at the class level (identical to dataclasses). This defines its children in the PyTree.
class MyModule(equinox.Module):
    weight: typing.Any
    bias: typing.Any
    submodule: equinox.Module
In this case a default __init__ method is provided, which just fills in these attributes with the arguments passed: MyModule(weight, bias, submodule) or MyModule(weight=weight, bias=bias, submodule=submodule). Alternatively you can provide an __init__ method yourself. (For example to specify dimension sizes instead of raw weights.) By the end of __init__, every attribute must have been assigned.
class AnotherModule(equinox.Module):
    weight: typing.Any

    def __init__(self, input_size, output_size, key):
        self.weight = jax.random.normal(key, (output_size, input_size))
After initialisation, attributes cannot be modified: models are immutable as per functional programming. (Parameter updates are made by creating a new model, not by mutating parameters in-place; see for example train_mlp.py.)
It is typical to also create some methods on the class. Because self is an input parameter, and is treated as a PyTree, these methods get access to the attributes of the instance. Defining __call__ gives an easy way to define a forward pass for a model:
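The "update by creating a new model" idea can be sketched without Equinox, using a frozen dataclass registered by hand as a PyTree node. This is an analogy under that assumption, not Module's implementation; the `Linear` class below is hypothetical.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Linear:
    # Attributes are declared at the class level, as with dataclasses.
    weight: Any
    bias: Any


# Register the class so that JAX treats its two fields as children.
jax.tree_util.register_pytree_node(
    Linear,
    lambda m: ((m.weight, m.bias), None),
    lambda aux_data, children: Linear(*children),
)

model = Linear(jnp.ones((2, 2)), jnp.zeros(2))

# An "update" builds a new instance; the original is left untouched.
new_model = jax.tree_util.tree_map(lambda leaf: leaf - 0.1, model)

# In-place assignment is rejected by the frozen dataclass.
try:
    model.weight = jnp.zeros((2, 2))
    mutation_rejected = False
except dataclasses.FrozenInstanceError:
    mutation_rejected = True
```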
class LinearWithoutBias(equinox.Module):
    weight: typing.Any

    def __call__(self, x):
        return self.weight @ x
If defining a method meth, then take care not to write instance = MyModule(...); jax.jit(instance.meth)(...). (Or similarly with jax.grad, equinox.jitf etc.) This is because instance.meth is not a pure function as it already has the self parameter passed implicitly. Instead do either jax.jit(MyModule.meth)(instance, ...) or
@jax.jit
def func(instance, args):
    return instance.meth(args)
# Also use this pattern with instance(args) if you defined `__call__` instead of `meth`.
Utilities
equinox.apply_updates(model, updates)
Performs a training update to a model.
model must be a PyTree; updates must be a PyTree with the same structure.
It essentially performs jax.tree_map(lambda m, u: m + u, model, updates). However anywhere updates is None then no update is made at all, so as to handle nondifferentiable parts of model.
The returned value is the updated model. (model is not mutated in place, as is usual in JAX and functional programming.)
To produce updates, it is typical to take the gradients from the loss function, and then adjust them according to any standard optimiser; for example Optax provides optax.sgd or optax.adam.
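The None-aware update described above can be sketched in a few lines of plain JAX. This is an illustration of the behaviour, not the library's source; `apply_updates_sketch` is a hypothetical name, and the plain SGD step stands in for an optimiser such as those from Optax.

```python
import jax
import jax.numpy as jnp


def apply_updates_sketch(model, updates):
    def update_leaf(u, m):
        # `None` marks a leaf with no update, such as a non-differentiable flag.
        return m if u is None else m + u

    # `updates` goes first so that its `None` entries are treated as leaves.
    return jax.tree_util.tree_map(update_leaf, updates, model, is_leaf=lambda x: x is None)


model = {"weight": jnp.array([1.0, 2.0]), "flag": False}
grads = {"weight": jnp.array([0.5, 0.5]), "flag": None}

# Plain SGD step. tree_map sees `None` as an empty subtree, so it passes through unchanged.
lr = 0.1
updates = jax.tree_util.tree_map(lambda g: -lr * g, grads)
new_model = apply_updates_sketch(model, updates)
```

The original `model` is left as it was; `new_model` has an updated `weight` and the same `flag`.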
equinox.tree_at(where, pytree, replace=_sentinel, replace_fn=_sentinel)
Modifies an existing tree, and returns the modified tree. (Like .at for "in place modifications" of JAX arrays.)
- where is a callable PyTree -> Leaf or PyTree -> Tuple[Leaf, ...]. It should consume a PyTree of the same shape as pytree, and return the leaf or leaves that should be replaced. For example where=lambda mlp: mlp.layers[-1].linear.weight.
- pytree is the existing PyTree to modify.
- replace should either be a single element, or a tuple of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
- replace_fn should be a function Leaf -> Any. It will be called on every leaf replaced using where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
For example this can be used to specify the weights of a model to train or not train:
trainable = jax.tree_map(lambda _: False, model)
trainable = equinox.tree_at(lambda mlp: mlp.layers[-1].linear.weight, trainable, replace=True)
equinox.gradf(..., filter_tree=trainable)
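To show what such a filter tree looks like, the sketch below builds one for a hypothetical model stored as nested dicts and lists, using only plain JAX. Because dicts and lists are mutable, the single leaf can be switched on by direct assignment here; with immutable modules, that step is the job of equinox.tree_at.

```python
import jax
import jax.numpy as jnp

# A hypothetical two-layer model stored as nested dicts and lists.
model = {
    "layers": [
        {"weight": jnp.ones((4, 2)), "bias": jnp.zeros(4)},
        {"weight": jnp.ones((1, 4)), "bias": jnp.zeros(1)},
    ]
}

# Start from an all-False tree with the same structure as the model...
trainable = jax.tree_util.tree_map(lambda _: False, model)
# ...then switch on the single leaf we want to train.
trainable["layers"][-1]["weight"] = True

# Leaves of the model and of the filter tree line up one-to-one,
# for example to count the trainable parameters.
model_leaves = jax.tree_util.tree_leaves(model)
flags = jax.tree_util.tree_leaves(trainable)
num_trainable = sum(leaf.size for leaf, keep in zip(model_leaves, flags) if keep)
```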
equinox.tree_equal(*pytrees)
Returns True if all PyTrees in the list are equal. All arrays must have the same shape, dtype, and values. JAX arrays and NumPy arrays are not considered equal.
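A comparison with these rules could look roughly like the sketch below: compare the tree structures first, then compare the leaves pairwise. This is an assumption about the behaviour described above and not the library's source; `tree_equal_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import numpy as np


def tree_equal_sketch(*pytrees):
    flat = [jax.tree_util.tree_flatten(tree) for tree in pytrees]
    first_leaves, first_treedef = flat[0]
    for leaves, treedef in flat[1:]:
        if treedef != first_treedef:
            return False
        for a, b in zip(first_leaves, leaves):
            # Different types never match, so a JAX array never equals a NumPy array.
            if type(a) is not type(b):
                return False
            if isinstance(a, (jax.Array, np.ndarray)):
                if a.shape != b.shape or a.dtype != b.dtype:
                    return False
                if not bool(np.array_equal(a, b)):
                    return False
            elif a != b:
                return False
    return True


same = tree_equal_sketch({"w": jnp.ones(2)}, {"w": jnp.ones(2)})
mixed = tree_equal_sketch({"w": jnp.ones(2)}, {"w": np.ones(2, dtype=np.float32)})
```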
Neural network library
Equinox includes a small neural network library, mostly as a tech demo for how the rest of the library can be used. Its API is modelled after PyTorch.
equinox.nn.Linear(in_features, out_features, bias=True, *, key)(input)
equinox.nn.Identity(*args, **kwargs)(input) # args and kwargs are ignored
equinox.nn.Dropout(p=0.5, deterministic=False)(input, *, key=None, deterministic=None)
equinox.nn.GRUCell(input_size, hidden_size, bias=True, *, key)(input, hidden)
equinox.nn.LSTMCell(input_size, hidden_size, bias=True, *, key)(input, hidden)
equinox.nn.Sequential(layers)(input, *, key=None)
equinox.nn.MLP(in_size, out_size, width_size, depth,
               activation=jax.nn.relu, final_activation=lambda x: x, *, key)(input)
These all behave in the way you expect. The key arguments are used to generate the random initial weights, or to generate randomness on the forward pass of stochastic layers like Dropout.
The Dropout(deterministic=...)(deterministic=...) options determine whether to have the layer act as the identity function, as is commonly done with dropout at inference time. The call-time deterministic takes precedence if it is passed; otherwise the init-time deterministic is used. (Note that because models are PyTrees, you can modify the init-time deterministic flag using equinox.tree_at. This is perfectly fine, and might be handy if it's easier than using the call-time flag.)
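The precedence rule for the two deterministic flags can be sketched as follows. This is a plain-Python stand-in, not the library's Dropout; the `DropoutSketch` class is hypothetical, and the rescaling of kept values by 1 / (1 - p) is an assumption of this sketch.

```python
import jax.numpy as jnp
import jax.random as jrandom


class DropoutSketch:
    def __init__(self, p=0.5, deterministic=False):
        self.p = p
        self.deterministic = deterministic

    def __call__(self, x, *, key=None, deterministic=None):
        # The call-time flag wins when it is passed; otherwise use the init-time flag.
        if deterministic is None:
            deterministic = self.deterministic
        if deterministic:
            return x
        if key is None:
            raise ValueError("A PRNG key is required when dropout is active.")
        keep = jrandom.bernoulli(key, 1.0 - self.p, x.shape)
        # Assumed convention: rescale the kept values so the expected output matches the input.
        return jnp.where(keep, x / (1.0 - self.p), 0.0)


layer = DropoutSketch(p=0.5, deterministic=True)
x = jnp.ones(4)

inference_out = layer(x)  # init-time flag: acts as the identity
training_out = layer(x, key=jrandom.PRNGKey(0), deterministic=False)  # call-time override
```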
The MLP(final_activation=...) option determines any final activation function to apply after the last layer. (In some cases it is desirable for this to be different to the activation used in the main part of the network.)