Memformer - Pytorch

Implementation of Memformer, a Memory-augmented Transformer, in Pytorch. It includes memory slots that are updated with attention, and it is trained efficiently with Memory Replay Back-Propagation (MRBP) through time.
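
To make the memory-slot idea concrete, here is a minimal pure-Python sketch of one attention-based memory update. It is an illustration under simplifying assumptions, not this library's implementation: there are no learned projections, there is a single head, and the function name `update_memory` is hypothetical. Each slot attends over itself plus the token vectors of one segment and is replaced by the attention-weighted average.

```python
import math

def softmax(scores):
    # numerically stable softmax over a list of floats
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def update_memory(memory, tokens):
    """Hypothetical single-head update: each slot attends to itself and to the tokens."""
    dim = len(memory[0])
    scale = 1.0 / math.sqrt(dim)
    new_memory = []
    for slot in memory:
        # keys and values are the raw vectors (assumption: no learned projections)
        candidates = [slot] + tokens
        weights = softmax([dot(slot, c) * scale for c in candidates])
        # the new slot is the attention-weighted average of the candidates
        new_slot = [sum(w * c[i] for w, c in zip(weights, candidates)) for i in range(dim)]
        new_memory.append(new_slot)
    return new_memory

memory = [[1.0, 0.0], [0.0, 1.0]]              # 2 slots, dim 2
tokens = [[0.5, 0.5], [1.0, 1.0], [0.0, 2.0]]  # 3 token vectors from one segment
updated = update_memory(memory, tokens)
print(len(updated), len(updated[0]))           # slot count and dim are preserved
```

Because the number of slots is fixed, the memory stays the same size no matter how many segments have been processed.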

Install

$ pip install memformer

Usage

Full encoder / decoder, as in the paper

import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    num_memory_slots = 128
)

src_seg_1 = torch.randint(0, 256, (1, 1024))
src_seg_2 = torch.randint(0, 256, (1, 1024))
src_seg_3 = torch.randint(0, 256, (1, 1024))

tgt = torch.randint(0, 256, (1, 1024))

enc_out1, mems1,    _ = model(src_seg_1) # (1, 1024, 512), (1, 128, 512), _
enc_out2, mems2,    _ = model(src_seg_2, mems = mems1)
enc_out3, mems3, loss = model(src_seg_3, tgt, mems = mems2)

loss.backward()
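
The three calls above follow one pattern: encode a segment, keep the returned memory, and pass it to the next call. A minimal pure-Python sketch of that pattern for an arbitrary number of segments is below. `split_into_segments`, `run_with_memory`, and the stand-in `toy_step` are hypothetical helpers for illustration and are not part of the memformer package; with the real model, the step would wrap a call such as `model(segment, mems = mems)`.

```python
def split_into_segments(tokens, segment_len):
    # consecutive chunks of at most segment_len items; the last may be shorter
    return [tokens[i:i + segment_len] for i in range(0, len(tokens), segment_len)]

def run_with_memory(step, segments, mems=None):
    """Call step(segment, mems) on each segment in order, threading the memory through."""
    outputs = []
    for segment in segments:
        out, mems = step(segment, mems)
        outputs.append(out)
    return outputs, mems

def toy_step(segment, mems):
    # stand-in for a model call: the "memory" is just a running count of tokens seen
    seen = (mems or 0) + len(segment)
    return sum(segment), seen

tokens = list(range(10))
segments = split_into_segments(tokens, segment_len=4)
outputs, final_mems = run_with_memory(toy_step, segments)
print(segments)    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(outputs)     # [6, 22, 17]
print(final_mems)  # 10
```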

Encoder only

import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_heads = 8,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    num_memory_slots = 128,
    num_mem_updates = 2,
    encoder_only = True       # use only the encoder; the model returns the encoded sequence and the updated memories
)

src1 = torch.randint(0, 256, (1, 1024))
src2 = torch.randint(0, 256, (1, 1024))

enc1, mems1 = model(src1) # (1, 1024, 512), (1, 128, 512)
enc2, mems2 = model(src2, mems = mems1)

Memory Replay Back-Propagation

import torch
from memformer import Memformer, memory_replay_backprop

model = Memformer(
    dim = 512,
    num_memory_slots = 128,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_max_seq_len = 1024
).cuda()

seq = torch.randint(0, 256, (1, 8192)).cuda()
seq_mask = torch.ones_like(seq).bool().cuda()

tgt = torch.randint(0, 256, (1, 512)).cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

# the source sequence is split automatically into 8 segments (8192 tokens / 1024 per segment)
memory_replay_backprop(
    model,
    src = seq,
    tgt = tgt,
    src_mask = seq_mask,
    tgt_mask = tgt_mask
)
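
The idea behind memory replay can be illustrated with a scalar toy in pure Python. This is a sketch under stated assumptions, not what `memory_replay_backprop` does internally: the "model" is the recurrence `m_next = w * m + x`, the loss is `0.5 * m_final ** 2`, the gradients are derived by hand, and all names are hypothetical. The first pass stores only the memory entering each segment. The second pass walks the segments in reverse, revisits each one from its stored memory, and hands the gradient back through the memory.

```python
def forward_segment(w, mem, x):
    # one "segment" step of the toy model
    return w * mem + x

def loss_fn(mem):
    return 0.5 * mem ** 2

def replay_backprop(w, xs, mem0=0.0):
    # pass 1: forward only, storing the memory that enters each segment
    stored = []
    mem = mem0
    for x in xs:
        stored.append(mem)
        mem = forward_segment(w, mem, x)
    loss = loss_fn(mem)

    # pass 2: visit segments from last to first
    grad_mem = mem  # d(loss)/d(m_final) for loss = 0.5 * m_final ** 2
    grad_w = 0.0
    for mem_in in reversed(stored):
        # A real network would rerun this segment's forward pass here with
        # gradients enabled. For the toy the local derivatives are known:
        # d(m_next)/d(w) = mem_in and d(m_next)/d(mem_in) = w.
        grad_w += grad_mem * mem_in
        grad_mem = grad_mem * w
    return loss, grad_w

w, xs = 0.5, [1.0, 2.0, 3.0]
loss, grad_w = replay_backprop(w, xs)

# sanity check against a central finite difference
eps = 1e-6
loss_hi, _ = replay_backprop(w + eps, xs)
loss_lo, _ = replay_backprop(w - eps, xs)
numeric = (loss_hi - loss_lo) / (2 * eps)
print(loss, grad_w, numeric)
```

Only one memory value per segment is kept between the two passes, which is the source of the memory savings compared with storing every segment's activations for ordinary back-propagation through time.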

Citations

@inproceedings{
    anonymous2021memformer,
    title={Memformer: The Memory-Augmented Transformer},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=_adSMszz_g9},
    note={under review}
}