Memory Efficient Attention (O(sqrt(n))) for Jax and PyTorch

Overview

Memory Efficient Attention

This is an unofficial implementation of "Self-attention Does Not Need O(n^2) Memory" for Jax and PyTorch.

The implementation is almost the same as the one proposed in the paper, with added support for masking and bias, batch dimensions, and a PyTorch implementation. For computing attention, the proposed method requires only O(sqrt(n)) memory, and the provided functions can be used as a drop-in replacement for attention calculation.
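To illustrate how chunking keeps memory low, here is a minimal NumPy sketch of the idea from the paper, written for this README rather than taken from the library. Queries are processed in chunks; for each query chunk, keys and values are visited in chunks that each keep only a summary (the exp-weighted values, their sum, and the chunk's maximum score). The summaries are then rescaled to a common maximum so the softmax stays numerically stable. The function names, the batch-free [length, heads, depth] layout, and the 1/sqrt(depth) scaling are assumptions for illustration.

```python
import numpy as np


def _query_chunk_attention(q, k, v, key_chunk_size):
    """Attention for one (already scaled) query chunk, visiting keys/values in chunks.

    q: [q_chunk, heads, depth]; k, v: [kv_len, heads, depth].
    Only per-chunk summaries are kept, never the full [q_chunk, heads, kv_len] scores.
    """
    values, weights, maxes = [], [], []
    for start in range(0, k.shape[0], key_chunk_size):
        k_c = k[start:start + key_chunk_size]
        v_c = v[start:start + key_chunk_size]
        scores = np.einsum("qhd,khd->qhk", q, k_c)        # [q_chunk, heads, key_chunk]
        m = scores.max(axis=-1, keepdims=True)             # chunk max, for stability
        e = np.exp(scores - m)
        values.append(np.einsum("qhk,khd->qhd", e, v_c))   # unnormalized chunk output
        weights.append(e.sum(axis=-1))                     # part of the softmax denominator
        maxes.append(m[..., 0])
    maxes = np.stack(maxes)                                # [n_chunks, q_chunk, heads]
    correction = np.exp(maxes - maxes.max(axis=0))         # rescale chunks to the global max
    total_values = (np.stack(values) * correction[..., None]).sum(axis=0)
    total_weights = (np.stack(weights) * correction).sum(axis=0)
    return total_values / total_weights[..., None]


def chunked_attention(q, k, v, query_chunk_size, key_chunk_size):
    """Chunked attention for q: [q_len, heads, depth], k/v: [kv_len, heads, depth]."""
    q = q * q.shape[-1] ** -0.5                            # 1/sqrt(depth) scaling (assumed)
    chunks = [
        _query_chunk_attention(q[s:s + query_chunk_size], k, v, key_chunk_size)
        for s in range(0, q.shape[0], query_chunk_size)
    ]
    return np.concatenate(chunks, axis=0)


def naive_attention(q, k, v):
    """Reference O(n^2)-memory attention with the same layout and scaling."""
    s = np.einsum("qhd,khd->qhk", q * q.shape[-1] ** -0.5, k)
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return np.einsum("qhk,khd->qhd", w, v)
```

With small chunk sizes, for example `chunked_attention(q, k, v, query_chunk_size=4, key_chunk_size=5)` on a 10-query, 13-key input, the result should match `naive_attention` up to floating-point error.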

Important Note: This implementation trades runtime for memory, so you should adjust the key_chunk_size and query_chunk_size parameters to find the best configuration for your use case. Here is a note from the paper's authors:

While a constant chunk size for the queries and a chunk size of sqrt(n) for the keys and values is optimal for memory consumption, the runtime is also affected by the choice of chunk size in practice, which is heavily affected by the choice of hardware. Ultimately, we have to leave this trade-off to the programmer, and expose the chunk sizes as arguments query_chunk_size and key_chunk_size. In Figure 1 we provide default values for the chunk sizes that lead to minimal runtime impact (on TPUv2), while still providing significant memory savings.
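A rough way to see this trade-off is to count the dominant temporaries while one query chunk is processed: the [query_chunk, heads, key_chunk] score block, plus one summary per key chunk. The sketch below is a back-of-the-envelope estimate under those assumptions only (it ignores inputs, outputs, and framework overhead); the helper names are hypothetical and not part of the library.

```python
import math


def attention_buffer_bytes(n, heads, depth, query_chunk_size, key_chunk_size, bytes_per_element=4):
    """Rough size of the main temporaries while one query chunk is processed.

    Counts one [query_chunk, heads, key_chunk] score block plus, per key chunk,
    a summary of exp-weighted values ([query_chunk, heads, depth]), their sum,
    and the chunk max ([query_chunk, heads] each). Inputs/outputs are ignored.
    """
    q = min(query_chunk_size, n)
    k = min(key_chunk_size, n)
    n_key_chunks = math.ceil(n / k)
    score_block = q * heads * k
    summaries = n_key_chunks * q * heads * (depth + 2)
    return (score_block + summaries) * bytes_per_element


def full_attention_score_bytes(n, heads, bytes_per_element=4):
    """Size of the full [heads, n, n] score matrix used by standard attention."""
    return heads * n * n * bytes_per_element
```

With 16,384 tokens, 16 heads, depth 64 and float32, the full score matrix alone is 16 GiB, while a query chunk of 1024 with key chunks of 4096 needs about 286 MB of these temporaries. Shrinking the chunks lowers memory further but means more, smaller operations, which is where the hardware-dependent runtime cost comes in.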

Quick Start

  1. Install the library
# for Jax
pip install memory-efficient-attention[jax]
# for PyTorch
pip install memory-efficient-attention[torch]
# for Running Tests
pip install memory-efficient-attention[testing]
  2. Compute attention with the appropriate function
import numpy as np
# for PyTorch
from memory_efficient_attention import efficient_dot_product_attention_pt
# or for Jax
from memory_efficient_attention import efficient_dot_product_attention_jax

# Random data (the leading batch dimensions are optional)
b = 8
query = np.random.rand(1, b, 128, 16, 8).astype("float32")
key = np.random.rand(1, b, 128, 16, 8).astype("float32")
value = np.random.rand(1, b, 128, 16, 8).astype("float32")
# optional mask and bias, e.g. for causal tasks, ...
mask = np.random.rand(1, b, 16, 128, 128) > 0.5
bias = np.random.rand(1, b, 16, 128, 128).astype("float32") / 100

# Adjust chunk sizes (replace ... with integers suited to your hardware)
efficient_dot_product_attention_jax(query, key, value, mask, bias, key_chunk_size=..., query_chunk_size=...)
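To sanity-check results on small inputs, a plain O(n^2)-memory NumPy reference in the same layout can help. This is a sketch, not part of the library, and it assumes conventions that were not verified against the library source: query/key/value are [..., length, heads, depth], mask and bias are [..., heads, q_length, kv_length], mask True means "attend", bias is added before masking, and scores are scaled by 1/sqrt(depth). Adjust the sketch if the library's outputs differ.

```python
import numpy as np


def reference_attention(query, key, value, mask=None, bias=None):
    """Unchunked attention for comparison (assumed conventions, see above).

    query: [..., q_len, heads, depth]; key/value: [..., kv_len, heads, depth]
    mask/bias: broadcastable to [..., heads, q_len, kv_len]; mask True = attend.
    """
    depth = query.shape[-1]
    scores = np.einsum("...qhd,...khd->...hqk", query, key) * depth ** -0.5
    if bias is not None:
        scores = scores + bias
    if mask is not None:
        # Masked positions get the most negative finite value, so exp() gives ~0.
        scores = np.where(mask, scores, np.finfo(scores.dtype).min)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("...hqk,...khd->...qhd", weights, value)
```

For the Quick Start data, `reference_attention(query, key, value, mask, bias)` returns an array of shape (1, b, 128, 16, 8), which can be compared with the library output using `np.allclose`.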

Citation

Please cite if this implementation helps your research. You can use the following BibTeX entry:

@misc{memory_efficient_attention,
	title = {Memory Efficient Attention},
	author = {Rezaei, Amin},
	howpublished = {\url{github.com/AminRezaei0x443/memory-efficient-attention}},
	year = {2021}
}

Also, for the paper:

@misc{rabe2021selfattention,
      title={Self-attention Does Not Need $O(n^2)$ Memory}, 
      author={Markus N. Rabe and Charles Staats},
      year={2021},
      eprint={2112.05682},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Comments
  • feat: output_attentions

    I'm looking into hacking some of the models in the transformers library to use this library for attention, and I don't see a way to support output_attentions yet. This is a flag in transformers that, when set, makes the model preserve the attention weights and return them to the user.

    I looked briefly at implementing this in the torch backend and noticed that the scan() function supports only a single tensor return value. It seems to me that scan() would be most clearly replaced by a for loop, but it could also be modified to handle tuples, or return_weights could be handled by accessing nonlocal data instead of returning the weights through the chunk scanner. I'm also not sure how the output would best be passed to the user.

    Edit: I posted a draft implementation on 01/28, in which I ended up extending the scan function for parity between implementations: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main...xloem:faba6371ac7faaa2040a2c26e15ae7ab87f94ce4

    Edit 2: It turns out these are the post-softmax attention weights, not the pre-softmax ones. I've updated this post and the draft implementation to return them: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main...xloem:return_weights

    opened by xloem
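Returning the weights means materializing the full [heads, q_length, kv_length] matrix, so the memory savings are lost while this option is enabled. A minimal NumPy sketch of query-chunked attention that also returns the post-softmax weights might look like this; the function name, the batch-free layout, and the 1/sqrt(depth) scaling are assumptions, not the draft implementation linked above.

```python
import numpy as np


def attention_with_weights(q, k, v, query_chunk_size=4):
    """Query-chunked attention that also returns post-softmax weights.

    q: [q_len, heads, depth]; k, v: [kv_len, heads, depth].
    Returns (output [q_len, heads, depth], weights [heads, q_len, kv_len]).
    """
    outs, weights = [], []
    for s in range(0, q.shape[0], query_chunk_size):
        q_c = q[s:s + query_chunk_size] * q.shape[-1] ** -0.5
        scores = np.einsum("qhd,khd->hqk", q_c, k)
        scores -= scores.max(axis=-1, keepdims=True)
        w = np.exp(scores)
        w /= w.sum(axis=-1, keepdims=True)        # post-softmax: each row sums to 1
        outs.append(np.einsum("hqk,khd->qhd", w, v))
        weights.append(w)
    return np.concatenate(outs, axis=0), np.concatenate(weights, axis=1)
```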
  • Provide a flag for the user to receive attention weights

    This is my draft code for #1. I saw this feature in the transformers library and wanted to implement it here.

    I'm curious what you think about this feature and implementation.

    The code is only lightly instrumented so that the final attention weights can be returned to the user. The tests are extended to cover this use. In utils, the scan function is expanded to handle tuples.

    A change to dynamic_slice crept in from dev: it uses slices rather than index_slice. I've kept this change because I expect it to execute faster, but it can be removed.

    Rebased and squashed from 84724e1de4721ea0333d6bdbb91e8bce74fbeac.

    opened by xloem
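The tuple-handling scan described in this PR can be sketched in plain Python. This is a hypothetical stand-in modelled on the (carry, x) -> (carry, y) contract of jax.lax.scan, not the library's utils.scan; it stacks each element of a tuple output separately, so a chunk function can return (output, weights).

```python
import numpy as np


def scan(f, init, xs):
    """Hypothetical pure-Python scan in the style of jax.lax.scan.

    f(carry, x) -> (new_carry, y), where y is an array or a tuple of arrays.
    Returns (final_carry, stacked_ys); tuple outputs are stacked element-wise.
    xs must be non-empty.
    """
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    if isinstance(ys[0], tuple):
        # One stacked array per tuple position, e.g. (outputs, weights).
        return carry, tuple(np.stack(parts) for parts in zip(*ys))
    return carry, np.stack(ys)
```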
  • Improve performance via batched-matmul and fused multiplies

    Many thanks for providing this reference implementation.

    I tried integrating this into stable-diffusion / diffusers. A fix was required to make it work on Mac (PyTorch MPS backend):
    https://github.com/Birch-san/diffusers/pull/1/commits/04372140a25d7f53549175f1f196599c3e9bf3a5

    Knowing that computing attention via baddbmm() + bmm() can outperform einsum by 18%, I tried rewriting the algorithm to use those.

    I compared the speed of my optimized version against the implementation in this repository. This result is for the "everything fits in one chunk" case (i.e. chunk size = max token length). I was unable to compare chunked performance: although I got chunking working in my version, I couldn't get it working in the version in this repository (it returned tensors with unexpected shapes).

    Compared to the implementation in this repository, my optimized version achieves a 2.78x speedup in the time it took to generate a 512x512 image with stable-diffusion v2.1-base (i.e. 4096 vision tokens, 5 attention heads, batch size of 2 due to CFG).

    Here's my optimized implementation:
    https://github.com/Birch-san/diffusers/pull/1

    Batched matmuls require a 3D tensor, i.e. [batch * num_heads, tokens, channels_per_head].

    Code that currently integrates against this repository's [batch, q_length, num_heads, qk_depth_per_head] format can migrate those tensors to the [batch * num_heads, q_length, channels_per_head] format favoured by my implementation like so:

    query = query.transpose(1,2).flatten(end_dim=1)
    key = key.transpose(1,2).flatten(end_dim=1)
    value = value.transpose(1,2).flatten(end_dim=1)
    

    The returned result remains in [batch * num_heads, q_length, qk_depth_per_head] format and can be restored to [batch, q_length, num_heads, qk_depth_per_head] format like so:

    result.unflatten(0, (-1, attn.heads)).transpose(1,2)
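To make the layout change and the batched-matmul formulation concrete, here is a NumPy sketch (not the linked diffusers code). to_batched and from_batched mirror transpose(1,2).flatten(end_dim=1) and unflatten(0, (-1, heads)).transpose(1,2); bmm_attention uses plain batched matmuls where the PR uses torch.baddbmm (with beta=0 and alpha set to the scale) and torch.bmm. The function names and the 1/sqrt(depth) scaling are assumptions.

```python
import numpy as np


def to_batched(x):
    """[batch, tokens, heads, depth] -> [batch * heads, tokens, depth]."""
    b, t, h, d = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b * h, t, d)


def from_batched(x, heads):
    """[batch * heads, tokens, depth] -> [batch, tokens, heads, depth]."""
    bh, t, d = x.shape
    return x.reshape(bh // heads, heads, t, d).transpose(0, 2, 1, 3)


def bmm_attention(q, k, v):
    """Unchunked attention on [batch * heads, tokens, depth] arrays via batched matmuls."""
    scale = q.shape[-1] ** -0.5
    scores = scale * (q @ k.transpose(0, 2, 1))    # analogous to torch.baddbmm(..., beta=0, alpha=scale)
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v                                    # analogous to torch.bmm(w, v)


def einsum_attention(q, k, v):
    """The same computation in the [batch, tokens, heads, depth] layout with einsum."""
    scale = q.shape[-1] ** -0.5
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) * scale
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", w, v)
```

Both paths compute the same attention. The 3D layout lets each step be a single batched matmul, which is what makes the baddbmm() + bmm() route possible.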
    

    I think a further speedup is possible by working out when chunking is not needed: we can compute whether unchunked attention would fit into memory and prefer unchunked attention as a fast path where possible. This will be useful in a UNet, which runs attention at various resolutions.

    EDIT:
    I have now added fast-paths for:

    • skipping kv-chunking when kv_chunk_size >= k_tokens
      • this turns the algorithm into "attention slicing"
    • skipping q-chunking when q_chunk_size >= q_tokens
    • skipping all chunking when the kv_chunk_size >= k_tokens and q_chunk_size >= q_tokens
    • skipping all chunking when the query-key matmul (q @ k^T) requires fewer bytes than a user-provided threshold
    opened by Birch-san
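The fast-path rules listed in this PR can be summarized as a small decision function. This is a hypothetical helper written for illustration; the names, return strings, and the 4-byte (float32) default are assumptions, not the PR's actual code.

```python
def qk_matmul_bytes(batch_x_heads, q_tokens, k_tokens, bytes_per_element=4):
    """Size of the [batch * heads, q_tokens, k_tokens] score tensor from q @ k^T."""
    return batch_x_heads * q_tokens * k_tokens * bytes_per_element


def choose_attention_path(batch_x_heads, q_tokens, k_tokens, q_chunk_size, kv_chunk_size,
                          bytes_per_element=4, unchunked_threshold_bytes=None):
    """Pick a strategy following the fast paths described above (hypothetical helper)."""
    qk_bytes = qk_matmul_bytes(batch_x_heads, q_tokens, k_tokens, bytes_per_element)
    if unchunked_threshold_bytes is not None and qk_bytes < unchunked_threshold_bytes:
        return "unchunked"                          # whole score matrix fits the budget
    skip_kv = kv_chunk_size >= k_tokens
    skip_q = q_chunk_size >= q_tokens
    if skip_kv and skip_q:
        return "unchunked"
    if skip_kv:
        return "query chunking only (attention slicing)"
    if skip_q:
        return "key/value chunking only"
    return "query and key/value chunking"
```

For the stable-diffusion example above (batch 2 x 5 heads, 4096 tokens, float32), the q @ k^T scores take 640 MiB, so a threshold of 1 GiB would select the unchunked fast path.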
Releases (0.1.3)
  • 0.1.2 (Mar 7, 2022)

    What's Changed

    This update fixes torch device-handling issues, so tensors on GPUs and other devices can now be used safely.

    • Update utils.py by @yhgon in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/5
    • Update attention_torch.py by @yhgon in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/6

    New Contributors

    • @yhgon made their first contribution in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/5

    Full Changelog: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/0.1.1.0...0.1.2

  • 0.1.1.0 (Feb 3, 2022)

    Added mask and bias calculation functions for custom, memory-efficient chunk computation, so masks and biases can now also be computed with sublinear memory.

    Full Changelog: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/0.1.1...0.1.1.0

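The release note does not show the new functions' signatures, so the sketch below only illustrates the general idea with a hypothetical helper, causal_mask_chunk: the causal mask block for any (query chunk, key chunk) pair can be computed from the chunk offsets, so the full n × n mask never has to be stored. Treating True as "may attend" is an assumption.

```python
import numpy as np


def causal_mask_chunk(q_start, k_start, q_len, k_len):
    """Causal mask block for queries [q_start, q_start + q_len) and keys [k_start, k_start + k_len).

    Hypothetical helper: True means the query may attend to the key (key position <= query position).
    Memory is proportional to q_len * k_len, independent of the full sequence length.
    """
    q_pos = np.arange(q_start, q_start + q_len)[:, None]
    k_pos = np.arange(k_start, k_start + k_len)[None, :]
    return k_pos <= q_pos
```

Assembling all blocks (only done here to check correctness) reproduces the usual lower-triangular causal mask.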
Owner
Amin Rezaei
Computer Science BSc, Neural Networks Enthusiast