Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"

Overview

FLASH - Pytorch


Install

$ pip install FLASH-pytorch

Usage

The main novel circuit in this paper is the "Gated Attention Unit", which they claim can replace multi-headed attention while reducing it to just one head.

It uses a squared ReLU activation in place of the softmax. That activation was first seen in the Primer paper, while the ReLA Transformer used plain ReLU in place of softmax. The gating style seems mostly inspired by gMLPs.

import torch
from flash_pytorch import GAU

gau = GAU(
    dim = 512,
    query_key_dim = 128,     # query / key dimension
    causal = True,           # autoregressive or not
    expansion_factor = 2,    # hidden dimension = dim * expansion_factor
)

x = torch.randn(1, 1024, 512)
out = gau(x) # (1, 1024, 512)
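To make the description above concrete, here is a minimal from-scratch sketch of a single-head gated attention unit in plain PyTorch. It is an illustration under stated assumptions, not the flash_pytorch implementation: `SimpleGAU` is a hypothetical name, deriving queries and keys from one shared projection via per-dimension scale and offset is our reading of the paper, the 1/n scaling of the similarity is an assumption, and positional encodings are left out.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SimpleGAU(nn.Module):
    # Hypothetical single-head gated attention unit (illustration only, no positional encodings).
    def __init__(self, dim, query_key_dim=128, expansion_factor=2, causal=False):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.causal = causal
        self.norm = nn.LayerNorm(dim)
        self.to_hidden = nn.Linear(dim, hidden_dim * 2)  # value branch + gate branch
        self.to_qk = nn.Linear(dim, query_key_dim)       # one shared projection for q and k
        # learned per-dimension scale / offset turn the shared projection into q and k
        self.gamma = nn.Parameter(torch.ones(2, query_key_dim))
        self.beta = nn.Parameter(torch.zeros(2, query_key_dim))
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        seq_len = x.shape[1]
        normed = self.norm(x)
        v, gate = F.silu(self.to_hidden(normed)).chunk(2, dim=-1)
        z = F.silu(self.to_qk(normed))
        q = z * self.gamma[0] + self.beta[0]
        k = z * self.gamma[1] + self.beta[1]
        # a single attention head; similarity scaled by sequence length (assumption)
        sim = torch.einsum('bid,bjd->bij', q, k) / seq_len
        attn = F.relu(sim) ** 2  # squared ReLU in place of softmax
        if self.causal:
            future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(1)
            attn = attn.masked_fill(future, 0.0)
        out = torch.einsum('bij,bje->bie', attn, v) * gate  # gate the attended values
        return self.to_out(out) + x  # project back to dim, residual connection


torch.manual_seed(0)
gau = SimpleGAU(dim=512, query_key_dim=128, causal=True)
x = torch.randn(1, 64, 512)
with torch.no_grad():
    out = gau(x)
print(out.shape)  # torch.Size([1, 64, 512])
```

The gate and the value branch come from the same expanded projection, which is what lets one head stand in for several: the element-wise gate re-weights each hidden feature per token after attention.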

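The authors then combine GAU with Katharopoulos linear attention, grouping the sequence into chunks to work around a known weakness of autoregressive linear attention: the cumulative sum over tokens that it needs during training is sequential and slow on accelerators.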

They named this combination of the quadratic gated attention unit with grouped linear attention FLASH.

It is just as easy to use:

import torch
from flash_pytorch import FLASH

flash = FLASH(
    dim = 512,
    group_size = 256,             # group size
    causal = True,                # autoregressive or not
    query_key_dim = 128,          # query / key dimension
    expansion_factor = 2.         # hidden dimension = dim * expansion_factor
)

x = torch.randn(1, 1111, 512)     # sequence is automatically padded up to the nearest multiple of group_size
out = flash(x) # (1, 1111, 512)
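The grouped scheme can be sketched in a few lines: quadratic squared-ReLU attention inside each group, plus linear attention across groups in which each group only sees a key-value summary of the groups before it. `grouped_causal_attention` is a hypothetical function, not part of flash_pytorch; it omits gating, positional encodings and the shared projections, and its normalization constants are illustrative assumptions. It also shows the padding up to a multiple of the group size.

```python
import torch
import torch.nn.functional as F


def grouped_causal_attention(quad_q, quad_k, lin_q, lin_k, v, group_size):
    # Hypothetical sketch. All inputs: (batch, seq_len, features).
    batch, seq_len, _ = v.shape
    # pad the sequence up to a multiple of group_size
    padding = (group_size - seq_len % group_size) % group_size
    tensors = [F.pad(t, (0, 0, 0, padding)) for t in (quad_q, quad_k, lin_q, lin_k, v)]
    groups = (seq_len + padding) // group_size
    quad_q, quad_k, lin_q, lin_k, v = [t.reshape(batch, groups, group_size, -1) for t in tensors]

    # 1) quadratic squared-ReLU attention inside each group, causal within the group
    sim = torch.einsum('bgid,bgjd->bgij', quad_q, quad_k) / group_size
    attn = F.relu(sim) ** 2
    future = torch.ones(group_size, group_size, dtype=torch.bool, device=v.device).triu(1)
    attn = attn.masked_fill(future, 0.0)
    quad_out = torch.einsum('bgij,bgje->bgie', attn, v)

    # 2) linear attention across groups: one key-value summary per group ...
    lin_kv = torch.einsum('bgnd,bgne->bgde', lin_k, v) / group_size
    # ... accumulated over groups and shifted by one group (exclusive cumulative sum),
    # so group g only sees groups 0..g-1; the cumsum runs over groups, not tokens
    lin_kv = F.pad(lin_kv.cumsum(dim=1), (0, 0, 0, 0, 1, -1), value=0.0)
    lin_out = torch.einsum('bgnd,bgde->bgne', lin_q, lin_kv)

    out = (quad_out + lin_out).reshape(batch, groups * group_size, -1)
    return out[:, :seq_len]  # drop the padding


torch.manual_seed(0)
n, d = 50, 32  # 50 is not a multiple of 16, so it is padded to 64 internally
qq, qk, lq, lk, v = (torch.randn(1, n, d) for _ in range(5))
out = grouped_causal_attention(qq, qk, lq, lk, v, group_size=16)
print(out.shape)  # torch.Size([1, 50, 32])
```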

Finally, you can use the full FLASH transformer described in the paper, with all of the paper's positional embeddings: the absolute positional embedding is scaled sinusoidal, the GAU quadratic attention gets a single-head T5 relative positional bias, and on top of that both the GAU quadratic attention and the linear attention are rotary embedded (RoPE).

import torch
from flash_pytorch import FLASHTransformer

model = FLASHTransformer(
    num_tokens = 20000,          # number of tokens
    dim = 512,                   # model dimension
    depth = 12,                  # depth
    causal = True,               # autoregressive or not
    group_size = 256,            # size of the groups
    query_key_dim = 128,         # dimension of queries / keys
    expansion_factor = 2.,       # hidden dimension = dim * expansion_factor
    norm_type = 'scalenorm',     # the paper claims scalenorm leads to faster training with no performance hit; the other option is 'layernorm' (the default)
    shift_tokens = True          # discovered by independent researcher @BlinkDL (Shenzhen); shifts half of the feature dimension forward one step along the sequence, which greatly improved convergence even further in my local experiments
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x) # (1, 1024, 20000)
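Rotary embeddings rotate each pair of query/key features by an angle proportional to the token position, so the dot product between a rotated query and a rotated key depends only on their relative offset. The sketch below checks that property. The helper names (`rotary_angles`, `rotate_half`, `apply_rotary`), the "rotate half" pairing convention and the base of 10000 are assumptions for illustration, not the flash_pytorch internals.

```python
import torch


def rotary_angles(seq_len, dim, base=10000.0):
    # one rotation frequency per feature pair (assumed RoFormer-style frequencies)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)  # (seq_len, dim // 2)
    return torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(x, angles):
    # rotates each feature pair (i, i + dim // 2) by its position-dependent angle
    return x * angles.cos() + rotate_half(x) * angles.sin()


torch.manual_seed(0)
dim, seq_len = 64, 32
angles = rotary_angles(seq_len, dim)
q = torch.randn(dim).expand(seq_len, dim)  # the same query vector at every position
k = torch.randn(dim).expand(seq_len, dim)  # the same key vector at every position
rq, rk = apply_rotary(q, angles), apply_rotary(k, angles)

# two (query position, key position) pairs with the same offset of 2
score_a = rq[3] @ rk[5]
score_b = rq[10] @ rk[12]
print(torch.allclose(score_a, score_b, atol=1e-4))  # True
```

Because the rotation is applied to queries and keys rather than added to the token embedding, it composes with both the quadratic and the linear attention paths.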
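The token shift behind `shift_tokens` can be shown on a tiny tensor. After the shift, each position keeps half of its own features and receives half of the previous position's features (the first position receives zeros), so the following layers see the current and the preceding token at once. `shift_tokens` below is a hypothetical standalone function written for illustration, not the library's code.

```python
import torch
import torch.nn.functional as F


def shift_tokens(x):
    # x: (batch, seq_len, dim); hypothetical standalone version of the shift
    x_shift, x_pass = x.chunk(2, dim=-1)
    # (0, 0) leaves the feature dim alone; (1, -1) prepends one zero step on the
    # sequence dim and drops the last step, so position t now holds position t - 1
    x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.0)
    return torch.cat((x_shift, x_pass), dim=-1)


x = torch.arange(12.0).reshape(1, 3, 4)  # 3 positions, 4 features
print(shift_tokens(x)[0].tolist())
# [[0.0, 0.0, 2.0, 3.0], [0.0, 1.0, 6.0, 7.0], [4.0, 5.0, 10.0, 11.0]]
```

The shift only looks backward, so it stays compatible with causal (autoregressive) models.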

Test on Autoregressive Enwik8

$ python train.py

Citations

@article{Hua2022TransformerQI,
    title   = {Transformer Quality in Linear Time},
    author  = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.10447}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
Comments
  • einsum operation in Linear Attention Part

    Hi, thanks a lot for FLASH-pytorch; it has helped me a lot. I found some differences from the paper in the linear attention part: https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L342-L343

    lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
    lin_out = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
    

    Here lin_kv is three-dimensional (b d e), whereas the code in the paper is

    lin_kv = tf.einsum('bhke,bgh->bgke', lin_kv, mask)
    linear = tf.einsum('bgnk,bgke->bgne', lin_q, lin_kv)
    

    where lin_kv is four-dimensional (b g k e). It seems that the two ways are not equivalent.

    Looking forward to your reply. Best,

    opened by ShomyLiu 5
  • mask error

    x = torch.randint(0, 20000, (1, 1024))
    mask = x.ne(0)
    logits = model(x, mask=mask)
    

    RuntimeError: The size of tensor a (1024) must match the size of tensor b (128) at non-singleton dimension 2

    opened by keyunluo 1
  • Speed on TPU

    Hi, thanks for the code! I tested it on a Google TPU v3, and training was slower than I expected. Maybe some operation is not lowered on TPU.

    opened by magicknight 0
  • About the "shift_tokens"

    Thank you for your amazing code.

    In the FLASH class there is a flag, shift_tokens, and the corresponding code is as follows:

    if self.shift_tokens:
        x_shift, x_pass = normed_x.chunk(2, dim = -1)
        x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
        normed_x = torch.cat((x_shift, x_pass), dim = -1)

    Assume normed_x has shape [1024, 512], so x_shift and x_pass each have shape [1024, 256]. The code then adds a row of zeros at the top of x_shift, removes its last row, and concatenates x_shift and x_pass back into normed_x.

    In my opinion, after the F.pad operation the rows of x_shift and x_pass no longer match.

    May I know why it works?

    Kang

    opened by kangzhao2 1
  • Cross-Attention?

    Hi, @lucidrains. Thank you for sharing this excellent implementation with us all! Do you have any thoughts as to what changes would need to be made to make cross-attention possible with your FLASH model?

    opened by amorehead 2
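Whether the two snippets in the first comment agree depends on the group mask, and the algebra can be checked numerically. In the sketch below the shapes are illustrative and `mix_groups` is a hypothetical helper. With an all-ones mask, the paper-style four-dimensional lin_kv holds the same global summary for every group, so the output equals the three-dimensional form quoted from FLASH-pytorch. With a lower-triangular mask it becomes a cumulative sum over groups, which the three-dimensional form cannot express. The comment does not say which mask applies in which setting, so this only checks the algebra.

```python
import torch

torch.manual_seed(0)
b, g, n, d, e = 2, 4, 8, 16, 16  # batch, groups, group size, key dim, value dim (illustrative)
lin_q = torch.randn(b, g, n, d)
lin_k = torch.randn(b, g, n, d)
v = torch.randn(b, g, n, e)

# three-dimensional form quoted from FLASH-pytorch: one summary shared by all groups
lin_kv = torch.einsum('bgnd,bgne->bde', lin_k, v) / n
out_3d = torch.einsum('bgnd,bde->bgne', lin_q, lin_kv)

# four-dimensional, paper-style form: per-group summaries mixed by a (b, g, h) group mask
per_group_kv = torch.einsum('bgnk,bgne->bgke', lin_k, v) / n


def mix_groups(mask):
    # mask[b, g, h] = 1 if group g may see group h
    lin_kv_4d = torch.einsum('bhke,bgh->bgke', per_group_kv, mask)
    return lin_kv_4d, torch.einsum('bgnk,bgke->bgne', lin_q, lin_kv_4d)


# all-ones mask: every group sees every group, identical to the 3-D form
_, out_all = mix_groups(torch.ones(b, g, g))
print(torch.allclose(out_all, out_3d, atol=1e-5))  # True

# lower-triangular mask: group g sees groups 0..g, i.e. a cumulative sum over groups
kv_causal, _ = mix_groups(torch.ones(g, g).tril().expand(b, g, g))
print(torch.allclose(kv_causal, per_group_kv.cumsum(dim=1), atol=1e-5))  # True
```

Note that the inclusive lower-triangular mask lets every position in group g use keys and values from later positions of that same group; a strictly lower-triangular mask (an exclusive cumulative sum) avoids this when the current group is already handled by quadratic attention.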
Owner
Phil Wang
Working with Attention. It's all we need