Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch

Overview

Memorizing Transformers - Pytorch

This repository deviates slightly from the paper: it uses hybrid attention over both local and distant (retrieved) attention logits rather than the sigmoid gate setup. It also uses cosine similarity attention (with a learned temperature) for the KNN attention layer.
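As a rough illustration of the idea, the sketch below computes attention for a single query in plain Python: cosine-similarity logits for local and retrieved key / value pairs are scaled by a temperature, concatenated, and passed through one joint softmax. This is a minimal sketch under stated assumptions, not the repository's code: the function names are hypothetical, and the temperature is a fixed number here rather than a learned parameter.

```python
import math

def l2_normalize(vec):
    # Scale a vector to unit length (guarding against a zero vector).
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]

def cosine_logits(query, keys, temperature):
    # Cosine similarity between the query and each key, scaled by a temperature.
    q = l2_normalize(query)
    return [temperature * sum(a * b for a, b in zip(q, l2_normalize(k))) for k in keys]

def hybrid_attention(query, local_kv, memory_kv, temperature=10.0):
    """Attend jointly over local and retrieved (key, value) pairs.

    Hypothetical helper: `temperature` is fixed here for illustration only.
    """
    pairs = list(local_kv) + list(memory_kv)
    logits = cosine_logits(query, [k for k, _ in pairs], temperature)
    peak = max(logits)                          # subtract the max for numerical stability
    weights = [math.exp(l - peak) for l in logits]
    total = sum(weights)
    weights = [w / total for w in weights]      # one softmax across both sources
    dim = len(pairs[0][1])
    out = [sum(w * v[i] for w, (_, v) in zip(weights, pairs)) for i in range(dim)]
    return out, weights
```

Because a single softmax spans both sources, local and retrieved entries compete directly for attention mass, so no separate gate is needed to mix them.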

Install

$ pip install memorizing-transformers-pytorch

Usage

import torch
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 20000,                 # number of tokens
    dim = 512,                          # dimension
    dim_head = 64,                      # dimension per attention head
    depth = 8,                          # number of layers
    memorizing_layers = (4, 5),         # which layers to have ANN memories
    max_knn_memories = 64000,           # maximum ANN memories to keep (once it hits this capacity, it will be reset for now, due to limitations in faiss' ability to remove entries)
    num_retrieved_memories = 32,        # number of ANN memories to retrieve
    clear_memories_on_sos_token_id = 1, # clear passed in ANN memories automatically for batch indices which contain this specified SOS token id - otherwise, you can also manually iterate through the ANN memories and clear the indices before the next iteration
)

data = torch.randint(0, 20000, (2, 1024)) # mock data

knn_memories = model.create_knn_memories(batch_size = 2) # create collection of KNN memories with the correct batch size (2 in example)

logits = model(data, knn_memories = knn_memories) # (2, 1024, 20000)
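The clear_memories_on_sos_token_id comment above says memories are cleared for batch indices that contain the SOS token. A minimal sketch of that selection step, in plain Python with nested lists standing in for a token tensor, might look like this; the helper name is hypothetical and is not part of the library.

```python
def batch_indices_with_token(token_ids, sos_token_id):
    # Return the indices of batch rows that contain the SOS token at least once.
    return [i for i, row in enumerate(token_ids) if sos_token_id in row]

# Row 0 contains token id 1, row 1 does not.
indices_to_clear = batch_indices_with_token([[5, 1, 7], [4, 4, 4]], sos_token_id=1)
```

The resulting list is the kind of argument you would use when clearing memories manually for selected batch rows.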

You can make the KNN memories read-only by passing add_knn_memory = False on forward. For example:

logits = model(data, knn_memories = knn_memories, add_knn_memory = False) # knn memories will not be updated

With Transformer-XL memories (only the memories that will be discarded will be added to the KNN memory)

import torch
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 8,
    memorizing_layers = (4, 5),
    max_knn_memories = 64000,
    num_retrieved_memories = 32,
    clear_memories_on_sos_token_id = 1,
    xl_memory_layers = (2, 3, 4, 5),      # xl memory layers - (https://arxiv.org/abs/2007.03356 shows you do not need XL memory on all layers, just the latter ones) - if a KNNAttention layer ends up using XL memories, only the XL memories that will be discarded will be added to long term memory
    xl_max_memories = 512,                # number of xl memories to keep
    shift_knn_memories_down = 1,          # let a layer look at the KNN memories this number of layers above
    shift_xl_memories_down = 1,           # let a layer look at the XL memories this number of layers above, shown to enhance receptive field in ernie-doc paper
)

data = torch.randint(0, 20000, (2, 1024)) # mock data

xl_memories = None

with model.knn_memories_context(batch_size = 2) as knn_memories:
    logits1, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
    logits2, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
    logits3, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)

    # ... and so on
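One way to picture "only the memories that will be discarded will be added to the KNN memory" is as a fixed-size cache whose overflow is handed to long-term memory. The sketch below uses plain Python lists in place of tensors; the function name is hypothetical, and the exact bookkeeping inside the library may differ.

```python
def update_xl_cache(xl_cache, new_items, xl_max_memories):
    """Append new items to the XL cache and split off whatever no longer fits.

    Returns (kept, evicted): `kept` is the new XL cache (most recent items),
    `evicted` is what would be written to the KNN memory in this sketch.
    """
    combined = list(xl_cache) + list(new_items)
    split = max(len(combined) - xl_max_memories, 0)
    return combined[split:], combined[:split]

kept, evicted = update_xl_cache([1, 2, 3], [4, 5], xl_max_memories=3)
# kept    -> [3, 4, 5]
# evicted -> [1, 2]
```

Under this reading, recent context stays in the XL cache where it is attended to densely, and only the older entries that fall out of it are indexed for nearest-neighbor retrieval.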

KNN Memory

This repository contains a wrapper around Faiss that can automatically store and retrieve key / values

import torch
from memorizing_transformers_pytorch import KNNMemory

memory = KNNMemory(
    dim = 64,                   # dimension of key / values
    max_memories = 64000,       # maximum number of memories to keep (will throw out the oldest memories for now if it overfills)
    num_indices = 2             # this should be equivalent to batch dimension, as each batch keeps track of its own memories, expiring when it sees a new document
)

memory.add(torch.randn(2, 512, 2, 64))  # (batch, seq, key | value, feature dim)
memory.add(torch.randn(2, 512, 2, 64))

memory.clear([0]) # clear batch 0, if it saw an <sos>

memory.add(torch.randn(2, 512, 2, 64))
memory.add(torch.randn(2, 512, 2, 64))

key_values, mask = memory.search(torch.randn(2, 512, 64), topk = 32)
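To make the add / clear / search cycle concrete, here is a brute-force stand-in for a single batch row's memory, written in plain Python with no Faiss. It is a sketch under assumptions: the class name is hypothetical, similarity is a plain dot product, overflow drops the oldest entries (as the max_memories comment above describes), and the mask is taken to mark which of the topk slots hold real hits.

```python
class ToyKNNMemory:
    """Brute-force illustration of one batch row's key / value memory."""

    def __init__(self, max_memories):
        self.max_memories = max_memories
        self.entries = []  # list of (key, value) pairs, oldest first

    def add(self, keys, values):
        self.entries.extend(zip(keys, values))
        overflow = len(self.entries) - self.max_memories
        if overflow > 0:
            del self.entries[:overflow]  # assumption: discard the oldest entries

    def clear(self):
        # Forget everything, e.g. when a new document starts.
        self.entries = []

    def search(self, query, topk):
        def score(entry):
            key, _ = entry
            return sum(q * k for q, k in zip(query, key))

        hits = sorted(self.entries, key=score, reverse=True)[:topk]
        missing = topk - len(hits)
        mask = [True] * len(hits) + [False] * missing  # False marks padded slots
        return hits + [(None, None)] * missing, mask

memory = ToyKNNMemory(max_memories=2)
memory.add([[1, 0], [0, 1], [1, 1]], ["a", "b", "c"])  # "a" is dropped as the oldest
hits, mask = memory.search([1, 0], topk=3)
```

A real index replaces the sort with approximate nearest-neighbor search, but the padding-plus-mask pattern is what lets a caller ask for a fixed topk even when the memory holds fewer entries.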

Training

Enwik8 training

$ python train.py

Todo

  • switch to ivfhnsw and just remember all memories
  • enwik8 demo
  • validation for enwik8
  • solve gradient accumulation problem by offering some way to scope reads and writes to knn memories with another indices array
  • setup text generation with memories
  • figure out how to deal with memories efficiently once capacity has been hit
  • try to speed up reading and writing to knn memories collection with multiprocessing

Citations

@article{wu2022memorizing,
  title   = {Memorizing transformers},
  author  = {Wu, Yuhuai and Rabe, Markus N and Hutchins, DeLesley and Szegedy, Christian},
  journal = {arXiv preprint arXiv:2203.08913},
  year    = {2022}
}
@article{Shazeer2019FastTD,
  title   = {Fast Transformer Decoding: One Write-Head is All You Need},
  author  = {Noam M. Shazeer},
  journal = {ArXiv},
  year    = {2019},
  volume  = {abs/1911.02150}
}
@Article{AlphaFold2021,
  author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
  journal = {Nature},
  title   = {Highly accurate protein structure prediction with {AlphaFold}},
  year    = {2021},
  doi     = {10.1038/s41586-021-03819-2},
  note    = {(Accelerated article preview)},
}
@inproceedings{Rae2020DoTN,
  title   = {Do Transformers Need Deep Long-Range Memory?},
  author  = {Jack W. Rae and Ali Razavi},
  booktitle = {ACL},
  year    = {2020}
}
@misc{ding2021erniedoc,
  title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
  author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
  year    = {2021},
  eprint  = {2012.15688},
  archivePrefix = {arXiv},
  primaryClass = {cs.CL}
}
@misc{henry2020querykey,
  title   = {Query-Key Normalization for Transformers},
  author  = {Alex Henry and Prudhvi Raj Dachapally and Shubham Pawar and Yuxuan Chen},
  year    = {2020},
  eprint  = {2010.04245},
  archivePrefix = {arXiv},
  primaryClass = {cs.CL}
}

Memory is Attention through Time - Alex Graves

Comments
  • Arguments to reproduce the models from the original paper?

    Hi lucidrains,

    This looks like excellent work! I have gone through the original paper and your repo, and am now trying to reproduce the model from the paper as closely as possible. Of course, the modifications you made such as hybrid attention instead of sigmoid gate are fine.

    Specifically, I would like to be able to try some of the variations in Table 4: image

    Suppose I'm interested in the 4th to last row with Context 512, Memory 8192, XL cache 512. Can you help me with the model arguments to do that? Here is my initial attempt, with reference to Section 4.2:

    model = MemorizingTransformer(
        num_tokens = 32000, # vocab 32k
        dim = 1024, 
        depth = 12,
        memorizing_layers = 9,
        max_knn_memories = 8192, # Memory column
        num_retrieved_memories = 32,
        clear_memories_on_sos_token_id = 1,
        xl_memory_layers = (6, 7, 8, 9),  # not sure about this?
        xl_max_memories = 512, # XL cache column
        shift_knn_memories_down = 1, 
        shift_xl_memories_down = 1,
        # which argument corresponds to Context column?
    ).cuda()
    
    

    A second question: what are the model arguments to reproduce the first row of Table 4, with neither memory nor XL cache? Thanks in advance.

    opened by manestay 1
  • KNNMemory add() does not appear to update self.knns

    Thanks for the nice implementation. I've adapted this code for my own use, so I don't have the whole stack that would reproduce this bug. However, you can check for yourself.

    The following code ought to update the KNN objects in the KNNMemory class:

    @delayed
    def knn_add(knn, key, db_offset):
        knn.add(key, ids = knn_insert_ids + db_offset)
    
    Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
    

    [link to that code here]

    However, even after repeated calls to add to the memory, calling KNNMemory.search() results in empty values. If you view self.knns at this point, self.is_trained remains False.

    When I modify the code as follows, this fixes the issue.

    @delayed
    def knn_add(knn, key, db_offset):
        knn.add(key, ids = knn_insert_ids + db_offset)
        return knn
    
    updated_knns = Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
    self.knns = updated_knns
    

    This will allow searches to return actual values.

    opened by vyaivo 0
  • FAISS hard reset

    Hello and thanks for this implementation!

    Do you know of any solutions to efficiently solve the "hard reset" problem in FAISS? I know that one could use IndexFlatL2 but that's not really efficient.

    Thank you!

    opened by itsdaniele 0
  •  index out of

    When I run train.py, I get the following error: "index out of range: Tried to access index 10218 out of table with 255 rows. at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418"

    opened by chxiag 0
  • Support for Multi-GPU training?

    Thank you so much for the great implementation. I would like to ask whether your implementation of Memorizing Transformer supports multi-GPU distributed training as in the original paper. If the MemorizingTransformer model is replicated on each GPU, then every GPU holds its own memory with its own Faiss retrieval index. Each replica therefore has a different memory database and retrieval index, which differs from the original paper. I believe the replicas on different GPUs should share the same retrieval context. This confuses me a lot.

    Thank you so much for your time. Looking forward to your response!

    opened by Victorwz 0
  • Dimensionality of key and values for Attention

    I have two questions about the key and value calculation in Attention (and similarly for KNNAttention).

    The relevant line is: https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/memorizing_transformers_pytorch.py#L135

    1. Why is there only one Linear layer to_kv, instead of 2 linear layers to_k and to_v?
    2. Why is the last dimension dim_head*2? I get that *2 is for both k and v, but what about dim_head? I thought q, k, v should all have the same final dimension (i.e. inner_dim==dim_head*heads). My understanding is that this means either a) there is only 1 attention head, or b) k and v are shared across all heads. Is there a reason this is done, or am I misunderstanding?

    In your Attention class for Performer, q, k, v all have the same dimensions.

    Thanks in advance!

    opened by manestay 8
  • Maybe scale is wrong

    https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/memorizing_transformers_pytorch.py#L237

    Shouldn't this be (1-scale)?

    opened by denadai2 3
Releases(0.3.10)
Owner
Phil Wang
Working with Attention. It's all we need