Implementation of Memory-Compressed Attention, from the paper "Generating Wikipedia By Summarizing Long Sequences"

Overview

Memory Compressed Attention

Implementation of the self-attention layer of the proposed Memory-Compressed Attention, in PyTorch. This repository offers both the causal and non-causal variants, and takes care of the padding when the sequence length is not divisible by the compression ratio.
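As a rough illustration of the padding arithmetic, the sketch below computes how many positions must be added so that a sequence length divides evenly by the compression factor, and how many compressed keys result. The helper name `pad_and_compressed_len` is hypothetical and not part of the package; it only assumes that every group of `compression_factor` positions is reduced to a single key. The modulo form is used because `compression_factor - (seq_len % compression_factor)` evaluates to the factor itself, not zero, when the length is already divisible.

```python
def pad_and_compressed_len(seq_len, compression_factor):
    """Return (padding, compressed_len) for a given sequence length.

    Hypothetical helper, for illustration only.
    """
    # (-n) % k is the smallest non-negative amount that makes n + padding divisible by k.
    padding = (-seq_len) % compression_factor
    compressed_len = (seq_len + padding) // compression_factor
    return padding, compressed_len


for n in (1024, 1023, 1025):
    print(n, pad_and_compressed_len(n, 3))
```

With the sequence length of 1024 and compression factor of 3 from the usage example, this gives 2 padded positions and 342 compressed keys.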

The code also resolves an edge case in the auto-regressive scenario, where the very first query has no keys to attend to. The solution is to use null key/values, appended to the final compressed set, so that there is always at least one key for every query to attend to.
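The edge case can be sketched in plain Python. The example assumes, as a simplification that may differ from the package's exact masking rule, that the sequence length is divisible by the compression factor and that a causal query may attend to a compressed block only when the last position of that block is not in its future. Both helper names are hypothetical.

```python
def visible_blocks(query_pos, seq_len, cf):
    """Indices of compressed key blocks a causal query may attend to.

    Assumes seq_len is divisible by cf and that block j covers
    positions j * cf through (j + 1) * cf - 1.
    """
    num_blocks = seq_len // cf
    return [j for j in range(num_blocks) if (j + 1) * cf - 1 <= query_pos]


def visible_with_null(query_pos, seq_len, cf):
    # "null" stands in for the learned null key/value, which is never masked.
    return ["null"] + visible_blocks(query_pos, seq_len, cf)


seq_len, cf = 6, 3
for q in range(seq_len):
    print(q, visible_blocks(q, seq_len, cf), visible_with_null(q, seq_len, cf))
```

Under these assumptions, queries 0 and 1 see no compressed block at all, so a softmax over their scores would have nothing to normalize over. Adding the null slot guarantees every query at least one key.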

Install

$ pip install memory_compressed_attention

Usage

import torch
from memory_compressed_attention import MemoryCompressedAttention

attn = MemoryCompressedAttention(
    dim = 512,
    heads = 8,                 # number of heads
    causal = False,            # auto-regressive or not
    compression_factor = 3,    # compression ratio
    dropout = 0.1              # dropout post-attention
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

attn(x, input_mask = mask) # (1, 1024, 512)

Citations

@misc{liu2018generating,
    title={Generating Wikipedia by Summarizing Long Sequences},
    author={Peter J. Liu and Mohammad Saleh and Etienne Pot and Ben Goodrich and Ryan Sepassi and Lukasz Kaiser and Noam Shazeer},
    year={2018},
    eprint={1801.10198},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

Comments
  • The order of masking and softmax operation

    Hi,

    In memory_compressed_attention.py, I'm wondering whether the softmax operation should be applied after masking. Also, should the entry in the mask be float('-inf') instead of -float('-inf')? If I have got something wrong, please correct me.

    Thanks!

    opened by cfeng16 3
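On the masking question above, a small pure-Python sketch (independent of the repository's code; `masked_softmax` is a hypothetical helper) illustrates the conventional order: scores are filled with negative infinity at masked positions first, and the softmax is applied afterwards. Masked positions then receive exactly zero weight and the remaining weights still sum to one. In Python, `-float('-inf')` is positive infinity, so using it as a fill value would do the opposite of masking.

```python
import math


def masked_softmax(scores, keep):
    """Softmax over scores, with positions where keep is False masked out.

    Assumes at least one position is kept; if every position is masked,
    the result is undefined (NaN) in this sketch.
    """
    masked = [s if k else float("-inf") for s, k in zip(scores, keep)]
    m = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in masked]  # exp(-inf) is exactly 0.0
    total = sum(exps)
    return [e / total for e in exps]


weights = masked_softmax([1.0, 2.0, 3.0], [True, True, False])
print(weights)
print(-float("-inf") == float("inf"))
```

The note about fully masked rows is also why an always-visible null key is useful: it keeps at least one finite score in every row.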
  • mask error in attention

    Very grateful for your pioneering work! I want to use it in the standard Transformer released at http://nlp.seas.harvard.edu/2018/04/03/attention.html, but I ran into a mask error during training. More details follow. The code I use:

    class ConvCompress(nn.Module):
        def __init__(self, dim, ratio = 2, groups = 1):
            super(ConvCompress, self).__init__()
            self.conv = nn.Conv1d(dim, dim, ratio, stride = ratio, groups = groups)
            #self.linear = nn.Linear(dim, dim)

        def forward(self, mem):
            mem = mem.transpose(1, 2)
            compressed_mem = self.conv(mem)
            return compressed_mem.transpose(1, 2)

    class MemoryCompressedAttention(nn.Module):
        def __init__(self, h, d_model, compression_factor = 2, dropout = 0.1):
            super(MemoryCompressedAttention, self).__init__()
            assert (d_model % h) == 0, 'dimension must be divisible by number of heads'
            self.h = h
            self.d_model = d_model
            self.d_k = d_model // h

            self.compression_factor = compression_factor
            self.compress_fn = ConvCompress(d_model, compression_factor, groups = h)

            #self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
            self.wq = nn.Linear(d_model, d_model, bias = False)
            self.wk = nn.Linear(d_model, d_model, bias = False)
            self.wv = nn.Linear(d_model, d_model, bias = False)

            self.wo = nn.Linear(d_model, d_model)

            self.dropout = nn.Dropout(dropout)

            #self.null_k = nn.Parameter(torch.zeros(1, 1, d_model))
            #self.null_v = nn.Parameter(torch.zeros(1, 1, d_model))

        def forward(self, query, key, value, mask = None):

            if mask is not None:
                # Same mask applied to all h heads.
                mask = mask.unsqueeze(1)
            nbatches = query.size(0)
            t = query.size(1)
            cf = self.compression_factor

            query = self.wq(query)
            key = self.wk(key)
            value = self.wv(value)

            # make sure keys and values sequence lengths
            # are divisible by the compression factor
            padding = cf - (t % cf)
            if padding != 0:
                key, value = map(lambda t: F.pad(t, (0, 0, padding, 0)), (key, value))

            # compress keys and values
            key, value = map(self.compress_fn, (key, value))

            # attach a null key and value, in the case that the first query has no keys to pay attention to
            null_k = nn.Parameter(torch.zeros(key.size(0), 1, self.d_model)).cuda()
            null_v = nn.Parameter(torch.zeros(value.size(0), 1, self.d_model)).cuda()

            key = torch.cat((null_k, key), dim=1)
            value = torch.cat((null_v, value), dim=1)

            # merge heads
            #query, key, value = map(lambda t: t.reshape(*t.shape[:2], h, -1).transpose(1, 2), (query, key, value))
            # 1) Do all the linear projections in batch from d_model => h x d_k
            query = query.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            key = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            value = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

            # 2) Apply attention on all the projected vectors in batch.
            x, self.attn = attention(query, key, value, mask=mask,
                                     dropout=self.dropout)

            # 3) "Concat" using a view and apply a final linear.   # split heads and combine
            x = x.contiguous().view(nbatches, -1, self.d_model)
            out = self.wo(x)

            return out
    

    The error is shown in the attached image.

    I would like to know how to fix it, and how to build the mask for an N*M attention matrix.

    opened by HN123-123 0
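One possible way to think about the N*M mask asked about above, sketched in plain Python under stated assumptions rather than taken from the repository: a per-position key mask of length N is padded on the left (matching `F.pad(t, (0, 0, padding, 0))` in the posted code), grouped into blocks of `cf` positions, and a compressed key is kept if any position in its block is valid. A leading always-valid slot accounts for the null key that the posted code concatenates in front. The helper `compress_key_mask` is hypothetical.

```python
def compress_key_mask(key_mask, cf):
    """Reduce a per-position key mask to one entry per compressed key.

    key_mask: list of bools, True where the position holds a real token.
    Returns a list of length M + 1, where M is the number of compressed
    keys and index 0 is the always-valid null key slot.
    """
    pad = (-len(key_mask)) % cf
    padded = [False] * pad + list(key_mask)  # left padding counts as invalid
    blocks = [padded[i:i + cf] for i in range(0, len(padded), cf)]
    compressed = [any(block) for block in blocks]
    return [True] + compressed


key_mask = [True, True, True, True, False, False, False]
print(compress_key_mask(key_mask, 3))
```

In this sketch every one of the N queries shares the same row, so the result would be broadcast over the query dimension to obtain an N x (M + 1) mask. A causal variant would need an additional per-query restriction on top of it.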
Releases(0.0.5)
Owner
Phil Wang
Working with Attention. It's all we need