FLASH - Pytorch

Implementation of the Transformer variant proposed in the paper Transformer Quality in Linear Time

Install

$ pip install FLASH-pytorch

Usage

The main novel circuit in this paper is the "Gated Attention Unit", which they claim can replace multi-headed attention while reducing it to just one head.

It uses a ReLU squared activation in place of the softmax; this activation was first seen in the Primer paper, and plain ReLU used in place of softmax was explored in the ReLA Transformer. The gating style seems mostly inspired by gMLPs.

import torch
from flash_pytorch import GAU

gau = GAU(
    dim = 512,
    query_key_dim = 128,     # query / key dimension
    causal = True,           # autoregressive or not
    expansion_factor = 2,    # hidden dimension = dim * expansion_factor
)

x = torch.randn(1, 1024, 512)
out = gau(x) # (1, 1024, 512)
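
Under the hood, the GAU computes a single head of attention from a shared query / key projection, replaces the softmax with ReLU squared, and gates the aggregated values. Below is a minimal sketch of that core computation, assuming simplified shapes and omitting the relative positional bias and the query / key offset-scale; the layer names here are hypothetical, not the library's internals.

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, query_key_dim, hidden_dim, n = 512, 128, 1024, 256

to_qk = nn.Linear(dim, query_key_dim)        # shared query / key projection (hypothetical name)
to_hidden = nn.Linear(dim, hidden_dim * 2)   # value and gate in a single projection
to_out = nn.Linear(hidden_dim, dim)

x = torch.randn(1, n, dim)
qk = to_qk(x)
v, gate = to_hidden(x).chunk(2, dim = -1)

sim = torch.einsum('b i d, b j d -> b i j', qk, qk) / n   # quadratic similarities
attn = F.relu(sim) ** 2                                   # relu squared in place of softmax
out = torch.einsum('b i j, b j e -> b i e', attn, v)
out = to_out(out * gate)                                  # gMLP-style elementwise gating, (1, 256, 512)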

The authors then combine the GAU with the linear attention of Katharopoulos et al., grouping the sequence into chunks to overcome a known issue with autoregressive linear attention.

This combination of the quadratic gated attention unit with grouped linear attention is what they named FLASH.

You can also use this quite easily:

import torch
from flash_pytorch import FLASH

flash = FLASH(
    dim = 512,
    group_size = 256,             # group size
    causal = True,                # autoregressive or not
    query_key_dim = 128,          # query / key dimension
    expansion_factor = 2.         # hidden dimension = dim * expansion_factor
)

x = torch.randn(1, 1111, 512)     # sequence will be auto-padded to the next multiple of the group size
out = flash(x) # (1, 1111, 512)
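
The auto-padding can be pictured as padding the sequence out to the next multiple of group_size and chunking it into groups, with the padding removed again on output. A rough sketch of the assumed semantics (using einops for readability):

import torch
import torch.nn.functional as F
from einops import rearrange

group_size = 256
x = torch.randn(1, 1111, 512)

padding = (group_size - x.shape[-2] % group_size) % group_size  # 169, bringing 1111 up to 1280
x = F.pad(x, (0, 0, 0, padding))                                # pad along the sequence dimension
groups = rearrange(x, 'b (g n) d -> b g n d', n = group_size)   # (1, 5, 256, 512)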

Finally, you can use the full FLASH transformer as described in the paper. This contains all the positional embeddings mentioned there: the absolute positional embedding uses scaled sinusoidal, the GAU quadratic attention gets a one-headed T5 relative positional bias, and on top of all this, both the GAU attention and the linear attention are rotary embedded (RoPE).

import torch
from flash_pytorch import FLASHTransformer

model = FLASHTransformer(
    num_tokens = 20000,          # number of tokens
    dim = 512,                   # model dimension
    depth = 12,                  # depth
    causal = True,               # autoregressive or not
    group_size = 256,            # size of the groups
    query_key_dim = 128,         # dimension of queries / keys
    expansion_factor = 2.,       # hidden dimension = dim * expansion_factor
    norm_type = 'scalenorm',     # in the paper, they claimed scalenorm led to faster training at no performance hit. the other option is 'layernorm', which is the default
    shift_tokens = True          # discovered by an independent researcher in Shenzhen @BlinkDL, this simply shifts half of the feature space forward one step along the sequence dimension - this further improved convergence in my local experiments
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x) # (1, 1024, 20000)
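
For illustration, the token shift enabled by shift_tokens amounts to just a few lines: half of the feature dimensions are shifted one step forward along the sequence, so each position mixes in features from the position before it. A sketch of the described behavior (the same operation is quoted in the comments below):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1024, 512)
x_shift, x_pass = x.chunk(2, dim = -1)                # split the feature dimension in half
x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)   # shift forward one step, dropping the last position
x = torch.cat((x_shift, x_pass), dim = -1)            # position t now carries half of position t-1's features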

Test on Autoregressive Enwik8

$ python train.py

Citations

@article{Hua2022TransformerQI,
    title   = {Transformer Quality in Linear Time},
    author  = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.10447}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
Comments
  • einsum operation in Linear Attention Part

    Hi, thanks a lot for FLASH-pytorch, which helps a lot. I found that there are some differences from the paper in the linear attention part: https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L342-L343

    lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
    lin_out = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
    

    Here lin_kv is three-dimensional (b d e), while the code in the paper is

    lin_kv = tf.einsum('bhke,bgh->bgke', lin_kv, mask)
    linear = tf.einsum('bgnk,bgke->bgne', lin_q, lin_kv)
    

    where lin_kv is four-dimensional (b g k e). It seems that the two ways are not equivalent.

    Looking forward to your reply. Best,

    opened by ShomyLiu 5
  • mask error

    x = torch.randint(0, 20000, (1, 1024))
    mask = x.ne(0)
    logits = model(x, mask=mask)
    

    RuntimeError: The size of tensor a (1024) must match the size of tensor b (128) at non-singleton dimension 2

    opened by keyunluo 1
  • Speed on TPU

    Hi, thanks for the code! I tested it on a Google TPU v3, and the training speed seems slower than I expected. Maybe some operation is not lowered efficiently on TPU.

    opened by magicknight 0
  • About the "shift_tokens"

    Thank you for your amazing code.

    In the FLASH class I found a flag, shift_tokens, with the corresponding code:

    if self.shift_tokens:
        x_shift, x_pass = normed_x.chunk(2, dim = -1)
        x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
        normed_x = torch.cat((x_shift, x_pass), dim = -1)

    Assume normed_x has shape [1024, 512]; then x_shift and x_pass each have shape [1024, 256]. The pad adds a row of zeros at the front of x_shift and removes its last row, and then x_shift and x_pass are concatenated to form the new normed_x.

    In my opinion, the F.pad operation makes the rows of x_shift and x_pass no longer line up.

    May I know why it works?

    Kang

    opened by kangzhao2 1
  • Cross-Attention?

    Hi, @lucidrains. Thank you for sharing this excellent implementation with us all! Do you have any thoughts as to what changes would need to be made to make cross-attention possible with your FLASH model?

    opened by amorehead 2