JMP is a Mixed Precision library for JAX.


Mixed precision training in JAX

Installation | Examples | Policies | Loss scaling | Citing JMP | References

Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model.
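As a rough illustration of the bandwidth argument, the sketch below uses only the Python standard library to compare the raw storage needed for the same number of values in IEEE half and single precision. The parameter count is an arbitrary assumption chosen for the example, not a figure from this library.

```python
import struct

# Bytes per element for IEEE 754 half ("e") and single ("f") precision.
HALF_BYTES = struct.calcsize("e")
FULL_BYTES = struct.calcsize("f")


def tensor_bytes(num_elements: int, bytes_per_element: int) -> int:
  """Returns the raw storage needed for a dense tensor."""
  return num_elements * bytes_per_element


# Hypothetical model size, used only for illustration.
num_params = 100_000_000
full_size = tensor_bytes(num_params, FULL_BYTES)
half_size = tensor_bytes(num_params, HALF_BYTES)
print(f"float32: {full_size / 2**20:.0f} MiB, float16: {half_size / 2**20:.0f} MiB")
```

Every value read or written in half precision moves half as many bytes, which is where the reduced memory bandwidth requirement comes from.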

This library implements support for mixed precision training in JAX by providing two key abstractions (mixed precision "policies" and loss scaling). Neural network libraries (such as Haiku) can integrate with jmp and provide "Automatic Mixed Precision (AMP)" support (automating or simplifying applying policies to modules).

All code examples below assume the following:

import jax
import jax.numpy as jnp
import jmp

half = jnp.float16  # On TPU this should be jnp.bfloat16.
full = jnp.float32

Installation

JMP is written in pure Python, but depends on C++ code via JAX and NumPy.

Because JAX installation is different depending on your CUDA version, JMP does not list JAX as a dependency in requirements.txt.

First, follow the JAX installation instructions to install JAX with the relevant accelerator support.

Then, install JMP using pip:

$ pip install git+https://github.com/deepmind/jmp

Examples

You can find a fully worked JMP example in Haiku which shows how to use mixed f32/f16 precision to halve training time on GPU and mixed f32/bf16 to reduce training time on TPU by a third.

Policies

A mixed precision policy encapsulates the configuration in a mixed precision experiment.

# Our policy specifies that we will store parameters in full precision but will
# compute and return output in half precision.
my_policy = jmp.Policy(compute_dtype=half,
                       param_dtype=full,
                       output_dtype=half)

The policy object can be used to cast pytrees:

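def layer(params, x):
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params["w"], params["b"]
  y = x @ w + b
  return my_policy.cast_to_output(y)

# Parameters are stored in the policy's parameter dtype (full precision here).
params = {"w": jnp.ones([3, 2], dtype=my_policy.param_dtype),
          "b": jnp.zeros([2], dtype=my_policy.param_dtype)}
x = jnp.ones([4, 3], dtype=full)  # Example input batch.
y = layer(params, x)
assert y.dtype == half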

You can replace the output type of a given policy:

my_policy = my_policy.with_output_dtype(full)

You can also define a policy via a string, which may be useful for specifying a policy as a command-line argument or as a hyperparameter to your experiment:

my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
float16_policy = jmp.get_policy("float16")  # Everything in f16.
half_policy = jmp.get_policy("half")        # Everything in half (f16 or bf16).
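To make the string format concrete, here is a minimal sketch of how such a string could be split into fields. `parse_policy_string` is a hypothetical helper written for this illustration; it is not how `jmp.get_policy` is implemented, it keeps dtype names as plain strings, and requiring all three fields is an assumption of the sketch.

```python
from typing import Dict

# Field names as they appear in the policy strings shown above.
_FIELDS = ("params", "compute", "output")


def parse_policy_string(spec: str) -> Dict[str, str]:
  """Splits a policy string into a {field: dtype_name} mapping."""
  if "=" not in spec:
    # A bare dtype name applies to parameters, compute and output alike.
    return {field: spec.strip() for field in _FIELDS}

  parsed = {}
  for part in spec.split(","):
    key, _, value = part.partition("=")
    key, value = key.strip(), value.strip()
    if key not in _FIELDS:
      raise ValueError(f"Unknown policy field: {key!r}")
    parsed[key] = value

  missing = [field for field in _FIELDS if field not in parsed]
  if missing:
    # Assumption of this sketch: every field must be given explicitly.
    raise ValueError(f"Missing policy fields: {missing}")
  return parsed


print(parse_policy_string("params=float32,compute=float16,output=float32"))
```

A string like this is easy to pass through a command-line flag or a hyperparameter sweep, which is the use case described above.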

Loss scaling

When training with reduced precision, consider whether gradients will need to be shifted into the representable range of the format that you are using. This is particularly important when training with float16 and less important for bfloat16. See the NVIDIA mixed precision user guide [1] for more details.

The easiest way to shift gradients is with loss scaling, which scales your loss and gradients by S and 1/S respectively.

def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.
  params = apply_optimizer(params, grads)
  return params

loss_scale = jmp.StaticLossScale(2 ** 15)
for _ in range(num_steps):
  params = train_step(params, loss_scale, ...)
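The effect of scaling can be seen without an accelerator. The sketch below rounds Python floats to IEEE half precision using the standard library `struct` module; `to_float16` and the gradient value are assumptions made for this illustration.

```python
import struct


def to_float16(x: float) -> float:
  """Rounds a Python float to the nearest IEEE 754 half precision value."""
  return struct.unpack("e", struct.pack("e", x))[0]


S = 2.0 ** 15
# Hypothetical gradient value, smaller than the smallest positive float16
# value (about 6e-8).
tiny_grad = 1e-8

unscaled = to_float16(tiny_grad)     # Underflows to zero in float16.
scaled = to_float16(tiny_grad * S)   # Lands in the representable range.
recovered = scaled / S               # Unscale in higher precision.

print(unscaled, scaled, recovered)
```

Without scaling the gradient is lost entirely, while the scaled value survives the round trip with only a small rounding error.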

The appropriate value for S depends on your model, loss, batch size and potentially other factors. You can determine this with trial and error. As a rule of thumb you want the largest value of S that does not introduce overflow during backprop. NVIDIA [1] recommends computing statistics about the gradients of your model (in full precision) and picking S such that its product with the maximum norm of your gradients is below 65,504 (the largest finite float16 value).
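The rule of thumb can be turned into a small calculation. The helper below is a sketch, not part of jmp: `largest_power_of_two_scale` is a hypothetical name, and restricting S to powers of two is an assumption (a common choice, since multiplying by a power of two only changes the exponent of a floating point value).

```python
import math

FLOAT16_MAX = 65504.0  # Largest finite float16 value.


def largest_power_of_two_scale(max_grad_norm: float) -> float:
  """Returns the largest power of two S with S * max_grad_norm < FLOAT16_MAX."""
  if not (math.isfinite(max_grad_norm) and max_grad_norm > 0):
    raise ValueError("max_grad_norm must be a positive finite number")
  exponent = math.floor(math.log2(FLOAT16_MAX / max_grad_norm))
  scale = 2.0 ** exponent
  # Guard against rounding in log2 so that the inequality is strict.
  while scale * max_grad_norm >= FLOAT16_MAX:
    scale /= 2.0
  return scale


# Hypothetical maximum gradient norm measured in full precision.
print(largest_power_of_two_scale(3.0))
```

In practice you may want to leave extra headroom below this value, since gradient norms measured on a few batches are only an estimate.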

We provide a dynamic loss scale, which adjusts the loss scale periodically during training to find the largest value for S that produces finite gradients. This is more convenient and robust compared with picking a static loss scale, but has a small performance impact (between 1 and 5%).

def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.

  # You definitely want to skip non-finite updates with the dynamic loss scale,
  # but you might also want to consider skipping them when using a static loss
  # scale if you experience NaN's when training.
  skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)

  if skip_nonfinite_updates:
    grads_finite = jmp.all_finite(grads)
    # Adjust our loss scale depending on whether gradients were finite. The
    # loss scale will be periodically increased if gradients remain finite and
    # will be decreased if not.
    loss_scale = loss_scale.adjust(grads_finite)
    # Only apply our optimizer if grads are finite, if any element of any
    # gradient is non-finite the whole update is discarded.
    params = jmp.select_tree(grads_finite, apply_optimizer(params, grads), params)
  else:
    # With static or no loss scaling just apply our optimizer.
    params = apply_optimizer(params, grads)

  # Since our loss scale is dynamic we need to return the new value from
  # each step. All loss scales are `PyTree`s.
  return params, loss_scale

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)
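The adjustment rule described in the comments above (grow the scale after a run of finite steps, shrink it when a step is non-finite) can be sketched in plain Python. `SketchDynamicScale` and its default period, factor and minimum are assumptions made for this illustration; they are not taken from `jmp.DynamicLossScale`.

```python
import dataclasses
import math
from typing import Sequence


@dataclasses.dataclass(frozen=True)
class SketchDynamicScale:
  """Illustrative stand-in for a dynamic loss scale (not the jmp class)."""
  loss_scale: float
  counter: int = 0             # Finite steps seen since the last change.
  period: int = 2000           # Assumed: finite steps required before growing.
  factor: float = 2.0          # Assumed: growth and shrink factor.
  min_loss_scale: float = 1.0  # Assumed: lower bound on the scale.

  def adjust(self, grads_finite: bool) -> "SketchDynamicScale":
    if not grads_finite:
      # Overflow: shrink the scale and restart the count.
      new_scale = max(self.loss_scale / self.factor, self.min_loss_scale)
      return dataclasses.replace(self, loss_scale=new_scale, counter=0)
    if self.counter + 1 >= self.period:
      # Enough finite steps in a row: try a larger scale.
      return dataclasses.replace(
          self, loss_scale=self.loss_scale * self.factor, counter=0)
    return dataclasses.replace(self, counter=self.counter + 1)


def all_finite(grads: Sequence[float]) -> bool:
  """True only if every gradient element is finite."""
  return all(math.isfinite(g) for g in grads)


scale = SketchDynamicScale(loss_scale=2.0 ** 15, period=3)
history = []
for grads in ([0.1, 0.2], [float("inf"), 0.1], [0.1], [0.2], [0.3]):
  scale = scale.adjust(all_finite(grads))
  history.append(scale.loss_scale)
print(history)
```

The state is immutable and a new value is returned from each call, mirroring the way the training step above returns the updated loss scale.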

In general using a static loss scale should offer the best speed, but we have optimized dynamic loss scaling to make it competitive. We recommend you start with dynamic loss scaling and move to static loss scaling if performance is an issue.

Finally, we offer a no-op loss scale that you can use as a drop-in replacement. It does nothing apart from implementing the jmp.LossScale API:

loss_scale = jmp.NoOpLossScale()
assert loss is loss_scale.scale(loss)
assert grads is loss_scale.unscale(grads)
assert loss_scale is loss_scale.adjust(grads_finite)
assert loss_scale.loss_scale == 1

Citing JMP

This repository is part of the DeepMind JAX Ecosystem. To cite JMP, please use the DeepMind JAX Ecosystem citation.

References

[0] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, Hao Wu: "Mixed Precision Training", 2017; arXiv:1710.03740 https://arxiv.org/abs/1710.03740.

[1] "Training With Mixed Precision :: NVIDIA Deep Learning Performance Documentation". Docs.Nvidia.Com, 2020, https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/.

Comments
  • Questions around speedup

    Hi,

    Thanks for creating this amazing library!

    So from what I understood this is the minimal needed to benefit from the JMP speedup

    policy = jmp.Policy(param_dtype=jnp.float32, compute_dtype=jnp.float16, output_dtype=jnp.float16)
    ...
    # creating a network, and creating a loss_fn using the network
    data = policy.cast_to_compute(data)
    params = policy.cast_to_compute(params)
    
    grads = jax.grad(loss_fn)(data, params)
    ...
    grads = policy.cast_to_param(grads)
    

    The loss scale is only needed if we experience NaN or inf values. And thus we should see a difference in the time needed for a jax.grad operation when the data and params are in float16.

    Please correct me if I'm wrong or if I miss anything.

    Considering that, I have some questions:

    • Can we expect to see a 2x or any speedup if we time the jax.grad operation with (params, data) in float16 and in float32 or is the speedup only achieved at scale ?
    • Can we expect to see a 2x or any speedup with small networks and small experiments ? (say 2-layer nets and experiments that only need 2-3 minutes on a single gpu)
    • Can we expect to see a 2x or any speedup with all types of networks or only some with specific architectures ? (say a 50 layer MLP, Transformers, etc)
    • Is it necessary to apply the mixed precision policy to the network and thus to use Haiku for hk.mixed_precision.set_policy or is having the parameters and data in float16 sufficient to have a speedup even with a network created by us ?

    Thank you and have a nice day !

    opened by 1m1ne 3
  • Basic question about the use of my_policy

    Hi,

    Thanks for creating this amazing functionality!

    I have a basic question about the use of policy functions to set certain precision levels. As stated in the example of your README.md:

    def layer(params, x):
      params, x = my_policy.cast_to_compute((params, x))
      w, b = params
      y = x @ w + b
      return my_policy.cast_to_output(y)
    

    I am new to JAX, but the first thing I learned is that JAX likes pure functions. Does the use of my_policy violate the pure functions paradigm?

    Should it become:

    def layer(params, x, my_policy):
      params, x = my_policy.cast_to_compute((params, x))
      w, b = params
      y = x @ w + b
      return my_policy.cast_to_output(y)
    

    The function can be jitted by using partial.

    from functools import partial
    import jax
    layer_compiled = jax.jit(partial(layer, my_policy=my_policy))
    

    Fundamental question

    A more fundamental question (which I maybe need to ask at JAX ) is: How should functions as input to functions be handled in JAX?

    def func_a(x, w, func):
      ...
    
    func_a_compiled = jit(partial(func_a, func=func_b))
    

    In order to jit this, I came up with the partial solution above. I assume your use of my_policy is valid, as you probably have more experience with JAX. But that creates some magic, which is undesirable for my use case. Is the jit(partial()) solution valid or is there a better way to handle functions as input to functions?

    Have a nice day! J

    opened by JSchuurmans 2
  • jmp-0.0.2.tar.gz on PyPI doesn't contain requirements.txt

    When trying to build my own wheel from the tar-ball of jmp-0.0.2 that has been uploaded to PyPI I get the following error:

    Collecting jmp==0.0.2
      Downloading jmp-0.0.2.tar.gz (13 kB)
        Running command python setup.py egg_info
        Traceback (most recent call last):
          File "<string>", line 1, in <module>
          File "/tmp/pip-download-zowwz99o/jmp_55ae4a2c4cdb44aa99c65fbf7a5ee9bb/setup.py", line 55, in <module>
            install_requires=_parse_requirements('requirements.txt'),
          File "/tmp/pip-download-zowwz99o/jmp_55ae4a2c4cdb44aa99c65fbf7a5ee9bb/setup.py", line 32, in _parse_requirements
            with open(requirements_txt_path) as fp:
        FileNotFoundError: [Errno 2] No such file or directory: 'requirements.txt'
    WARNING: Discarding https://files.pythonhosted.org/packages/7c/ba/a6bfcaeedca8551e2fb4054d1fd061a0dd97d26dd44002b3e92d13b51877/jmp-0.0.2.tar.gz#sha256=fdb5cec0d10aab4116c2770f24b2adf4f503fcfbb96ce8ef583e1879bdbf1b9b (from https://pypi.org/simple/jmp/). Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.
    ERROR: Could not find a version that satisfies the requirement jmp==0.0.2 (from versions: 0.0.1, 0.0.2)
    ERROR: No matching distribution found for jmp==0.0.2
    

    After downloading and extracting jmp-0.0.2.tar.gz from PyPI and looking at setup.py, I see that it contains lines that reference requirements.txt:

    install_requires=_parse_requirements('requirements.txt')
    

    however that file is not part of the source distribution and my build fails.

    Oliver

    P.S. Yes I know that there is a wheel for jmp-0.0.2 on PyPI, which I ended up using, but I'm using our Wheels_builder script that will recursively build wheels for dependencies for our systems. My point is that the source distribution of jmp 0.0.2 is incomplete as it lacks the information that is contained in the requirements file.

    opened by ostueker 1
  • Casting numpy array

    I'll first start by saying that I've just started with JAX, so I might be doing something wrong.

    When I run the following code:

    precision = jmp.get_policy('params=float32,compute=float16,output=float16')
    some_input = np.arange(15).reshape((5, 3))
    @jax.jit
    def some_function(some_input):
        some_input = precision.cast_to_compute(some_input)
        print(some_input.dtype)
        # prints float32
        return a_float16_compute_model(some_input)  # fails
    

    It seems that at least on the first (tracing) run, the numpy array doesn't get cast to float16, maybe because it is being treated as a tree, here

    https://github.com/deepmind/jmp/blob/4b94370b8de29b79d6f840b09d1990b91c1afddd/jmp/_src/policy.py#L26

    Surprisingly, if I run

    some_input = np.arange(15).reshape((5, 3)).astype(precision.compute_dtype)
    print(some_input.dtype)
    
    #prints float16
    

    the cast succeeds. I think that the expected behavior is that precision.cast_to_compute(some_input) should return a numpy array with compute_dtype, but I might be missing something.

    opened by yardenas 1
  • [JAX] Fix test failure due to upcoming change to JAX.

    An upcoming change to JAX adds a @jit decorator around a number of array operators, in this case division (/). A side effect of that change is that large Python integer constants that overflow an int32 or int64 type may produce an error. The workaround is either to explicitly cast the large constants to a specific type (e.g., np.float64), or in this case we can just do the math in question in classic NumPy since it is computing test expectations.

    opened by copybara-service[bot] 0
  • Bump version to `0.0.3.dev`.

    This is so if users install from GitHub we can see they are not using a stable version in bug reports:

    $ pip install git+https://github.com/deepmind/jmp
    $ python -c 'import jmp ; print(jmp.__version__)'
    0.0.3.dev
    
    opened by copybara-service[bot] 0
  • Cut 0.0.2 release and create GitHub action to publish to PyPi.

    We are not using 0.0.1 because this version was already used by the previous owner of the pypi package (https://pypi.org/project/jmp/0.0.1/). As such our releases will start at 0.0.2.

    opened by copybara-service[bot] 0
  • Replace jax.lax.select with jnp.where

    Thanks for the awesome work!

    This PR fixes an issue where jax.lax.select complains about dtypes not being equal when adjusting the DynamicLossScale.

    The exception's stack trace ends with:

    File ".../jmp/_src/loss_scale.py", line 147, in adjust
        loss_scale = jax.lax.select(
    TypeError: lax.select requires arguments to have the same dtypes, got float32, int32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
    

    My code looks similar to

    scale = jmp.DynamicLossScale(jnp.asarray(2 ** 15))
    ...
    gradients, scale = gradient_fn(..., scale)
    gradients = scale.unscale(gradients)
    gradients_finite = jmp.all_finite(gradients)
    scale = scale.adjust(gradients_finite)  # This line throws the exception
    ...
    
    opened by nlsfnr 6
Releases (v0.0.2)
  • v0.0.2 (Apr 15, 2021)

    Initial release of JMP.

    Changelog:

    • Add jmp.Policy abstraction and jmp.get_policy(..) factory.
    • Add jmp.LossScale and three implementations thereof (noop, static and dynamic).
    • Add various utilities (jmp.all_finite) to support common tasks in mixed precision codebases.