Implementation of Feedback Transformer in PyTorch

Overview

Feedback Transformer - Pytorch

A simple implementation of the Feedback Transformer in PyTorch. It improves on Transformer-XL by giving each token access to the representations of all layers from previous time steps. This is achieved by aggregating the outputs of all layers into a shared memory that each token can attend to from every layer at each time step.
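The aggregation step can be sketched as a weighted sum over the layer axis. This is a minimal sketch, not the repo's code. It assumes the paper's scheme of one learnable score per layer, softmax-normalized across layers, and it assumes the stacked layer outputs have shape (depth, batch, seq, dim). `aggregate_layers` and `layer_logits` are illustrative names.

```python
import torch

def aggregate_layers(layer_outputs, layer_logits):
    """Collapse per-layer outputs into one memory tensor via a softmax-weighted sum.

    layer_outputs: (depth, batch, seq, dim), outputs of every layer stacked
    layer_logits:  (depth,), one learnable score per layer (assumed scheme)
    """
    weights = layer_logits.softmax(dim=0)        # (depth,), sums to 1
    # (depth, 1, 1, 1) so the weight broadcasts over batch, seq and dim;
    # same as einops rearrange(weights, 'd -> d () () ()')
    weights = weights.view(-1, 1, 1, 1)
    return (weights * layer_outputs).sum(dim=0)  # (batch, seq, dim)

depth, batch, seq, dim = 6, 2, 4, 8
outs = torch.randn(depth, batch, seq, dim)
memory = aggregate_layers(outs, torch.nn.Parameter(torch.zeros(depth)))
print(tuple(memory.shape))  # (2, 4, 8)
```

With equal scores, the result reduces to the plain mean over layers. Letting the scores be learned allows the model to favor higher or lower layers when building the memory.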

The main drawback is longer training time, because the computation cannot be parallelized across time steps. Still, I thought I'd build it to encourage further exploration and research into this line of work.
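The toy loop below shows where the sequential dependency comes from. Every layer at step t attends over memory vectors from earlier steps. Step t's own memory vector is then a mix of that step's layer outputs, so step t + 1 cannot start until step t has finished.

This is a hedged single-head sketch, not the repo's code. Its assumptions:

- There are no query/key/value projections.
- A `Linear` layer stands in for the feedforward.
- The input embedding counts as one of the mixed "layers".
- The current token's own representation is included among the keys/values, so the first step has something to attend to.

`attend` and `feedback_forward` are hypothetical names.

```python
import torch

def attend(x, memory):
    # x: (batch, dim) current token at this layer; memory: (batch, m, dim)
    # Assumption: keys/values are the past memory vectors plus the token itself.
    kv = torch.cat([memory, x.unsqueeze(1)], dim=1)              # (batch, m + 1, dim)
    scores = torch.einsum('bd,bmd->bm', x, kv) / x.shape[-1] ** 0.5
    weights = scores.softmax(dim=-1)
    return torch.einsum('bm,bmd->bd', weights, kv)

def feedback_forward(tokens, layers, layer_logits):
    # tokens: (batch, T, dim) already-embedded inputs
    batch, T, dim = tokens.shape
    memory = tokens.new_zeros(batch, 0, dim)                     # no past yet
    mix = layer_logits.softmax(dim=0)                            # (depth + 1,)
    outputs = []
    for t in range(T):                                           # strictly sequential over t
        x = tokens[:, t]
        per_layer = [x]
        for layer in layers:
            x = x + attend(x, memory)                            # reads memory from steps < t
            x = x + layer(x)                                     # toy feedforward
            per_layer.append(x)
        stacked = torch.stack(per_layer)                         # (depth + 1, batch, dim)
        m_t = (mix.view(-1, 1, 1) * stacked).sum(dim=0)          # (batch, dim)
        memory = torch.cat([memory, m_t.unsqueeze(1)], dim=1)    # step t + 1 needs this
        outputs.append(x)
    return torch.stack(outputs, dim=1), memory

torch.manual_seed(0)
layers = [torch.nn.Linear(16, 16) for _ in range(3)]
out, mem = feedback_forward(torch.randn(2, 5, 16), layers, torch.zeros(4))
print(tuple(out.shape), tuple(mem.shape))  # (2, 5, 16) (2, 5, 16)
```

A standard Transformer can process all T positions in one batched pass under a causal mask. Here, the inner loop over layers sits inside the loop over t, which is the source of the extra training time.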

Yannic Kilcher video

I also took the liberty of adding a few enhancements, including pre-normalization, GLU-gated feedforwards, and simplified T5 relative positional embeddings.
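T5-style relative positions map each query/key distance to a bucket, and each bucket carries a learned bias that is added to the attention logits. The sketch below follows the commonly used causal form of T5's bucketing, in plain Python.

Its assumptions are:

- The defaults `num_buckets=32` and `max_distance=128`.
- A plain list standing in for the learned per-bucket biases.
- Queries aligned to the end of the keys, with memory first.

This is not necessarily the exact variant used in this repo. `relative_position_bucket` and `bias_matrix` are illustrative names.

```python
import math

def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    """Causal T5-style bucketing of (key_pos - query_pos) into a bucket index."""
    n = max(-relative_position, 0)        # distance into the past; future positions map to 0
    max_exact = num_buckets // 2
    if n < max_exact:                     # small distances each get their own bucket
        return n
    # larger distances share logarithmically sized buckets, capped at the last one
    val = max_exact + int(
        math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return min(val, num_buckets - 1)

def bias_matrix(q_len, k_len, bucket_bias, **kwargs):
    """Per-(query, key) bias; queries are assumed to be the last q_len key positions."""
    offset = k_len - q_len
    return [
        [bucket_bias[relative_position_bucket(j - (i + offset), **kwargs)] for j in range(k_len)]
        for i in range(q_len)
    ]

print([relative_position_bucket(-d) for d in (0, 5, 16, 32, 128, 1000)])  # [0, 5, 16, 21, 31, 31]
```

Nearby tokens are distinguished exactly, while distant ones share coarse buckets. This keeps the number of learned parameters small and independent of the memory length.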

Install

$ pip install feedback-transformer-pytorch

Usage

import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,           # number of tokens
    dim = 512,                    # dimension
    depth = 6,                    # depth
    seq_len = 2,                  # the sequence length of each segment or window
    mem_len = 256,                # length of the memory buffer
    dim_head = 64,                # dimension of each head
    heads = 8,                    # number of heads
    attn_dropout = 0.1,           # attention dropout
    ff_dropout = 0.1              # feedforward dropout
).cuda()

x = torch.randint(0, 20000, (2, 64)).cuda()
model(x)  # (2, 64, 20000)
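The comment on seq_len describes it as "the sequence length of each segment or window". Going by that comment, the (2, 64) input above would be consumed as 32 windows of 2 tokens each. The per-window logits would then be stitched back into a (2, 64, 20000) tensor.

The sketch below assumes that windowing scheme. `run_in_windows`, `embed` and `to_logits` are hypothetical stand-ins, not part of feedback_transformer_pytorch.

```python
import torch

def run_in_windows(tokens, seq_len, window_fn):
    """Process (batch, total_len) token ids window by window and stitch the logits.

    window_fn is a hypothetical stand-in for the per-window model step; it must map
    (batch, w) token ids to (batch, w, num_tokens) logits.
    """
    windows = tokens.split(seq_len, dim=1)   # last window is shorter if total_len % seq_len != 0
    logits = [window_fn(window) for window in windows]
    return torch.cat(logits, dim=1)          # (batch, total_len, num_tokens)

# Stand-in "model": an embedding followed by a projection back to the vocabulary
num_tokens, seq_len = 20000, 2
embed = torch.nn.Embedding(num_tokens, 8)
to_logits = torch.nn.Linear(8, num_tokens)

x = torch.randint(0, num_tokens, (2, 64))
with torch.no_grad():
    out = run_in_windows(x, seq_len, lambda w: to_logits(embed(w)))
print(len(x.split(seq_len, dim=1)), tuple(out.shape))  # 32 (2, 64, 20000)
```

Whatever the window size, the input is (batch, total_len) token ids and the output is (batch, total_len, num_tokens) logits. A smaller seq_len only means more sequential window steps.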

If you would like finer control over the memory (when to detach it, etc.), you can pass some extra keyword arguments to .forward:

import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 32,
    mem_len = 256
).cuda()

x1 = torch.randint(0, 20000, (2, 32)).cuda()
x2 = torch.randint(0, 20000, (2, 32)).cuda()
x3 = torch.randint(0, 20000, (2, 32)).cuda()

out1, mem1 = model(x1, return_memory = True)
out2, mem2 = model(x2, memory = mem1, return_memory = True)
out3, mem3 = model(x3, memory = mem2, return_memory = True)  # (2, 32, 20000)
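Passing memory between segments like this is what enables truncated backpropagation. You train on one segment, detach the memory, and carry it into the next segment.

The sketch below shows that loop with a hypothetical `ToyMemoryModel` that only mimics the `memory=` / `return_memory=` keywords. The real model's memory format is not shown here. For that reason, `detach_memory` (also hypothetical) walks tensors, lists, tuples and namedtuples rather than assuming a single tensor.

```python
import torch
import torch.nn.functional as F

def detach_memory(memory):
    """Detach every tensor in a (possibly nested) memory structure from the graph."""
    if torch.is_tensor(memory):
        return memory.detach()
    if hasattr(memory, '_fields'):                         # namedtuple
        return type(memory)(*(detach_memory(m) for m in memory))
    if isinstance(memory, (list, tuple)):
        return type(memory)(detach_memory(m) for m in memory)
    return memory                                          # e.g. None

class ToyMemoryModel(torch.nn.Module):
    """Hypothetical stand-in that mimics the memory= / return_memory= keywords."""
    def __init__(self, num_tokens, dim):
        super().__init__()
        self.embed = torch.nn.Embedding(num_tokens, dim)
        self.to_logits = torch.nn.Linear(dim, num_tokens)

    def forward(self, x, memory=None, return_memory=False):
        h = self.embed(x)
        if memory is not None:
            h = h + memory.mean(dim=1, keepdim=True)       # toy read from the carried memory
        new_memory = h if memory is None else torch.cat([memory, h], dim=1)
        out = self.to_logits(h)
        return (out, new_memory) if return_memory else out

torch.manual_seed(0)
model = ToyMemoryModel(num_tokens=100, dim=16)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

memory = None
for segment in torch.randint(0, 100, (2, 96)).split(32, dim=1):
    logits, memory = model(segment, memory=memory, return_memory=True)
    loss = F.cross_entropy(logits.transpose(1, 2), segment)  # toy objective
    opt.zero_grad()
    loss.backward()
    opt.step()
    memory = detach_memory(memory)   # cut the graph at the segment boundary
```

If the detach is skipped, the second backward() tries to reach into the first segment's already-freed graph. PyTorch typically refuses with a RuntimeError unless retain_graph=True is used. Detaching keeps the memory's values but stops gradients at the segment boundary.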

Citations

@misc{fan2021addressing,
    title   = {Addressing Some Limitations of Transformers with Feedback Memory}, 
    author  = {Angela Fan and Thibaut Lavril and Edouard Grave and Armand Joulin and Sainbayar Sukhbaatar},
    year    = {2021},
    eprint  = {2002.09402},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Should it really be using lower layers output for keys and values?

    Could you explain the logic of how the key-value pairs are formed at these lines and whether it is necessary?

    https://github.com/lucidrains/feedback-transformer-pytorch/blob/d7d8939910d1491f01a3d93ce81d4663925fb389/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L146-L151

    It looks to me like line 146 transforms the output of the layer below (x) into keys and values, and the following lines combine these keys and values with the memory. I thought x should only be used to form the query here, with only the existing memory used for keys and values.

    opened by tarvaina 6
  • In place operation with gradient

    https://github.com/lucidrains/feedback-transformer-pytorch/blob/main/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L173 I think this is an error.

    opened by hadaev8 4
  • Bug in weighted sum

    Bug in https://github.com/lucidrains/feedback-transformer-pytorch/blob/main/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L264

    Should be layer_weight = rearrange(layer_weight, 'd -> d () () ()')

    opened by Victor0118 1
  • Input/Output dimensions

    Hey @lucidrains

    Could you confirm the dimensions of the input and output? Is it (seq_len, dim) -> (?, dim, tokens)?

    model = FeedbackTransformer(
        num_tokens = 20000,           # number of tokens
        dim = 512,                    # dimension
        depth = 6,                    # depth
        seq_len = 2,                  # the sequence length of each segment or window
        mem_len = 256,                # length of the memory buffer
        dim_head = 64,                # dimension of each head
        heads = 8,                    # number of heads
        attn_dropout = 0.1,           # attention dropout
        ff_dropout = 0.1              # feedforward dropout
    ).cuda()
    
    x = torch.randint(0, 256, (2, 512)).cuda()
    model(x)  # (1, 512, 20000)
    
    opened by iiSeymour 1
  • Non intuitive memory usage with cross attention

    Given a simple tensor with dim 256 and length 512, and a memory length of 16, the feedback transformer uses 3.6 GB of memory after the forward pass. With cross attention on a tensor of length 100, usage grows to 14 GB.

    The parallel version, by comparison, uses 3.1 GB and 3.5 GB respectively.

    Notebooks for testing https://colab.research.google.com/drive/1dRImydFn3WthOXdLYIvdf5bsqjXcmhC5?usp=sharing https://colab.research.google.com/drive/1n653j4Pz9_U7OukhTlUbomAHMvpPXwx0?usp=sharing

    opened by hadaev8 0
  • I think mask padding value should be False

    Here https://github.com/lucidrains/feedback-transformer-pytorch/blob/with-cross-attention/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L181

    opened by hadaev8 0
  • ETA for the enwiki8 example

    Hey @lucidrains,

    Any ETA on the autoregressive enwiki8 example? Others and I would really appreciate it, as always :)

    Also, if you can provide an example for training on custom line-by-line TXT datasets, it would be absolutely fantastic.

    Thank you.

    opened by asigalov61 0
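The "In place operation with gradient" comment concerns a common autograd pitfall. The sketch below is a generic illustration of that failure mode, not a reproduction of what line 173 of the repo does. It shows two cases:

- Modifying, in place, a tensor that autograd saved for the backward pass breaks backward().
- The out-of-place version of the same computation works.

The function names are hypothetical.

```python
import torch

def inplace_breaks_backward():
    a = torch.randn(3, requires_grad=True)
    b = a.exp()            # exp saves its output to compute the gradient later
    b.add_(1)              # in-place edit bumps b's version counter
    try:
        b.sum().backward()
    except RuntimeError as err:
        return str(err)    # autograd detects that a saved tensor was modified
    return None

def out_of_place_works():
    a = torch.randn(3, requires_grad=True)
    b = a.exp()
    c = b + 1              # new tensor; the saved output of exp is untouched
    c.sum().backward()
    return torch.allclose(a.grad, a.detach().exp())

print(inplace_breaks_backward() is not None, out_of_place_works())  # True True
```

Replacing an in-place op (add_, masked_fill_, slice assignment) with its out-of-place form costs an extra allocation. In exchange, the tensors autograd saved stay valid.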
Owner
Phil Wang
Working with Attention. It's all we need.