Implementation of Feedback Transformer in Pytorch

Overview

Feedback Transformer - Pytorch

Simple implementation of the Feedback Transformer in Pytorch. It improves on Transformer-XL by giving each token access to the representations of all layers at previous time steps, including higher layers, whereas in Transformer-XL a layer only sees the past states of the layer below it. This is achieved by aggregating the outputs of all layers into a shared memory, which every layer can attend to at each time step.
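As a rough illustration of the shared memory, here is a minimal pure-Python sketch, assuming (as described in the paper) that the memory vector for a time step is a softmax-weighted sum of that token's representations at every layer, including the embedding. The names `softmax` and `feedback_memory_vector` are hypothetical and not part of this library's API.

```python
import math

def softmax(logits):
    # numerically stable softmax over a plain list of floats
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def feedback_memory_vector(layer_outputs, layer_logits):
    """Collapse one token's per-layer vectors into a single memory vector.

    layer_outputs: depth + 1 vectors (embedding plus each layer's output), each a list of floats
    layer_logits:  one learnable scalar per entry in layer_outputs
    """
    weights = softmax(layer_logits)
    dim = len(layer_outputs[0])
    return [sum(w * h[i] for w, h in zip(weights, layer_outputs)) for i in range(dim)]

# equal logits give a plain average across layers
print(feedback_memory_vector([[1.0, 2.0], [3.0, 4.0]], [0.0, 0.0]))  # [2.0, 3.0]
```

In this sketch the memory holds one vector per past time step, shared by all layers, rather than one vector per layer per step.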

The main drawback is longer training time, since tokens must be processed one time step after another rather than in parallel. Still, I thought I'd build it to further exploration and research into this line of work.
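To see why the sequence cannot be processed in parallel, here is a toy, scalar-valued sketch (not the library's code): each step's layers read a memory that only exists once the previous step has finished. The mean over memory stands in for attention, the uniform average over layers stands in for the learned softmax weights, and `toy_feedback_pass` is a hypothetical name.

```python
from collections import deque

def toy_feedback_pass(inputs, depth=3, mem_len=4):
    """Process scalar 'token embeddings' strictly one time step at a time."""
    memory = deque(maxlen=mem_len)  # oldest entries are dropped once mem_len is reached
    outputs = []
    for x in inputs:
        # stand-in for attention: average of the memory written at earlier steps
        context = sum(memory) / len(memory) if memory else 0.0
        hiddens = [x]
        h = x
        for _ in range(depth):
            h = h + context         # every layer reads the same shared memory
            hiddens.append(h)
        # stand-in for the learned softmax weights: uniform average over all layers
        memory.append(sum(hiddens) / len(hiddens))
        outputs.append(h)
    return outputs, list(memory)

print(toy_feedback_pass([1.0, 0.0]))  # ([1.0, 3.0], [1.0, 1.5])
```

The outer loop over time steps is the bottleneck: step t needs the memory written at step t - 1, which in turn depends on the top layer at that step.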

Yannic Kilcher video

I also took the liberty of adding a few enhancements, including pre-normalization, GLU-gated feedforwards, and simplified T5 relative positional embeddings.
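For reference, T5-style relative positional embeddings map each query-key distance to one of a fixed number of buckets: exact buckets for short distances, log-spaced buckets for longer ones. Each bucket then indexes a learned per-head bias that is added to the attention logits. Below is a pure-Python sketch of T5's causal (unidirectional) bucketing rule. The defaults `num_buckets = 32` and `max_distance = 128` are T5's and are an assumption here, not necessarily this repository's values; `causal_relative_position_bucket` is a hypothetical name.

```python
import math

def causal_relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    """Map (key_position - query_position) to a bucket index, T5-style, causal variant.

    Future keys (relative_position > 0) share bucket 0 with distance 0; past distances
    below num_buckets // 2 each get their own bucket, larger ones are log-spaced up to
    max_distance, and everything beyond that shares the last bucket.
    """
    n = max(-relative_position, 0)  # distance into the past
    max_exact = num_buckets // 2
    if n < max_exact:
        return n
    log_ratio = math.log(n / max_exact) / math.log(max_distance / max_exact)
    bucket = max_exact + int(log_ratio * (num_buckets - max_exact))
    return min(bucket, num_buckets - 1)

print([causal_relative_position_bucket(-d) for d in (0, 5, 16, 32, 128, 1000)])
# [0, 5, 16, 21, 31, 31]
```

In a full model, a learned table of shape (num_buckets, heads) would be indexed by these buckets to build a (heads, query_len, key_len) bias. The simplified variant used here may differ in details such as how the bias is scaled.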

Install

$ pip install feedback-transformer-pytorch

Usage

import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,           # number of tokens
    dim = 512,                    # dimension
    depth = 6,                    # depth
    seq_len = 2,                  # the sequence length of each segment or window
    mem_len = 256,                # length of the memory buffer
    dim_head = 64,                # dimension of each head
    heads = 8,                    # number of heads
    attn_dropout = 0.1,           # attention dropout
    ff_dropout = 0.1              # feedforward dropout
).cuda()

x = torch.randint(0, 20000, (2, 64)).cuda()
model(x)  # (2, 64, 20000)

If you would like fine-grained control over the memory (when to detach, etc.), you can use some extra keyword arguments on .forward

import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 32,
    mem_len = 256
).cuda()

x1 = torch.randint(0, 20000, (2, 32)).cuda()
x2 = torch.randint(0, 20000, (2, 32)).cuda()
x3 = torch.randint(0, 20000, (2, 32)).cuda()

out1, mem1 = model(x1, return_memory = True)
out2, mem2 = model(x2, memory = mem1, return_memory = True)
out3, mem3 = model(x3, memory = mem2, return_memory = True)  # (2, 32, 20000)
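The pattern above generalizes to a loop over consecutive segments. Here is a minimal sketch that relies only on the `model(x, memory = ..., return_memory = True)` call shown above; `run_segments` and its `detach_fn` hook are hypothetical helpers, not part of the library. Because the structure of the returned memory object isn't documented here, detaching is left to a function you supply.

```python
def run_segments(model, segments, detach_fn=None):
    """Feed consecutive segments through `model`, threading memory between calls.

    segments:  iterable of inputs, e.g. torch.randint(0, 20000, (2, 96)).split(32, dim = 1)
    detach_fn: optional callable applied to the memory after each segment, e.g. to
               cut the autograd graph for truncated backpropagation through time
    """
    memory = None
    outputs = []
    for segment in segments:
        if memory is None:
            # first segment: call exactly as in the example above, without memory
            out, memory = model(segment, return_memory=True)
        else:
            out, memory = model(segment, memory=memory, return_memory=True)
        if detach_fn is not None:
            memory = detach_fn(memory)
        outputs.append(out)
    return outputs, memory
```

With tensors, the per-segment logits can be reassembled with `torch.cat(outputs, dim = 1)`. Whether to detach after every segment or keep the graph across several is a training choice that the `memory` and `return_memory` arguments leave to the caller.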

Citations

@misc{fan2021addressing,
    title   = {Addressing Some Limitations of Transformers with Feedback Memory}, 
    author  = {Angela Fan and Thibaut Lavril and Edouard Grave and Armand Joulin and Sainbayar Sukhbaatar},
    year    = {2021},
    eprint  = {2002.09402},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Should it really be using lower layers output for keys and values?

    Could you explain the logic of how the key-value pairs are formed at these lines and whether it is necessary?

    https://github.com/lucidrains/feedback-transformer-pytorch/blob/d7d8939910d1491f01a3d93ce81d4663925fb389/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L146-L151

    It looks to me like line 146 transforms the output of the layer below (x) into keys and values, and the following lines combine these keys and values with the memory. I thought x should only be used to form the query here, with keys and values coming only from the existing memory.

    opened by tarvaina 6
  • In place operation with gradient

    https://github.com/lucidrains/feedback-transformer-pytorch/blob/main/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L173 I think this is an error.

    opened by hadaev8 4
  • Bug in weighted sum

    Bug in https://github.com/lucidrains/feedback-transformer-pytorch/blob/main/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L264

    Should be layer_weight = rearrange(layer_weight, 'd -> d () () ()')

    opened by Victor0118 1
  • Input/Output dimensions

    Hey @lucidrains

    Can I check the dimensions of the input and output? Is it (seq_len, dim) -> (?, dim, tokens)?

    model = FeedbackTransformer(
        num_tokens = 20000,           # number of tokens
        dim = 512,                    # dimension
        depth = 6,                    # depth
        seq_len = 2,                  # the sequence length of each segment or window
        mem_len = 256,                # length of the memory buffer
        dim_head = 64,                # dimension of each head
        heads = 8,                    # number of heads
        attn_dropout = 0.1,           # attention dropout
        ff_dropout = 0.1              # feedforward dropout
    ).cuda()
    
    x = torch.randint(0, 256, (2, 512)).cuda()
    model(x)  # (1, 512, 20000)
    
    opened by iiSeymour 1
  • Non intuitive memory usage with cross attention

    Given a simple tensor of dimension 256 and length 512, with a memory length of 16, the feedback transformer uses 3.6 GB of memory after the forward pass. With cross attention over a tensor of length 100, usage grows to 14 GB.

    By comparison, the parallel version uses 3.1 GB and 3.5 GB, respectively.

    Notebooks for testing https://colab.research.google.com/drive/1dRImydFn3WthOXdLYIvdf5bsqjXcmhC5?usp=sharing https://colab.research.google.com/drive/1n653j4Pz9_U7OukhTlUbomAHMvpPXwx0?usp=sharing

    opened by hadaev8 0
  • I think mask padding value should be False

    Here https://github.com/lucidrains/feedback-transformer-pytorch/blob/with-cross-attention/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L181

    opened by hadaev8 0
  • ETA for the enwiki8 example

    Hey @lucidrains,

    Any ETA on the auto-regressive enwiki8 example? Others and I would really appreciate it, as always :)

    Also, if you could provide an example of training on custom line-by-line TXT datasets, it would be absolutely fantastic.

    Thank you.

    opened by asigalov61 0
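For context on the weighted-sum issue above: broadcasting aligns shapes starting from the trailing axis, so a per-layer weight of shape (depth,) lines up with the feature axis of a stacked tensor rather than its layer axis. Assuming the stacked layer outputs have shape (depth, batch, seq, dim), as the suggested `'d -> d () () ()'` pattern implies, the sketch below applies NumPy/PyTorch broadcasting rules in pure Python to show why the reshape to (depth, 1, 1, 1) is needed. `broadcast_shape` is a hypothetical helper (PyTorch offers `torch.broadcast_shapes` for the same check), and the sizes are illustrative.

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """Result shape of broadcasting shapes a and b under NumPy/PyTorch rules."""
    result = []
    # compare axes right-to-left; a missing leading axis behaves like size 1
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast {a} with {b}")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))

depth, batch, seq, dim = 7, 2, 32, 512
stacked = (depth, batch, seq, dim)

print(broadcast_shape((depth, 1, 1, 1), stacked))  # (7, 2, 32, 512): one weight per layer
try:
    broadcast_shape((depth,), stacked)  # depth is compared against dim: 7 vs 512
except ValueError as err:
    print(err)
```

If the sizes happened to coincide (depth == dim), no error would be raised and the weights would silently scale features instead of layers, which is why the explicit reshape matters.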
Owner
Phil Wang