[NeurIPS 2020] Official Implementation: "SMYRF: Efficient Attention using Asymmetric Clustering".

SMYRF: Efficient attention using asymmetric clustering

Get started:

Colab

Abstract

We propose a novel type of balanced clustering algorithm to approximate attention. Attention complexity is reduced from O(N^2) to O(N log N), where N is the sequence length. Our algorithm, SMYRF, uses Locality Sensitive Hashing (LSH) in a novel way by defining new asymmetric transformations and an adaptive scheme that produces balanced clusters. The biggest advantage of SMYRF is that it can be used as a drop-in replacement for dense attention layers without any retraining. In contrast, prior fast attention methods impose constraints (e.g. tied queries and keys) and require retraining from scratch. We apply our method to pre-trained state-of-the-art Natural Language Processing and Computer Vision models and report significant memory and speed benefits. Notably, SMYRF-BERT slightly outperforms BERT on GLUE while using 50% less memory. We also show that SMYRF can be used interchangeably with dense attention before and after training. Finally, we use SMYRF to train GANs with attention at high resolutions. Using a single TPU, we train BigGAN on CelebA-HQ with attention at resolutions 128x128 and 256x256, producing models capable of generating realistic human faces.
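The abstract describes the method only at a high level. Below is a minimal, dependency-free Python sketch of the general idea, written for illustration; it is not the code in this repository, and every function name in it is hypothetical. It assumes asymmetric transformations of the form F(q) = [q, 0, sqrt(M_Q^2 - ||q||^2)] and G(k) = [k, sqrt(M_K^2 - ||k||^2), 0], where M_Q and M_K are the largest query and key norms, so that the Euclidean distance between F(q) and G(k) shrinks as the inner product q·k grows. Hashing is one random Gaussian projection per round, balanced clusters come from sorting by that projection and cutting the order into equal-size chunks, and several hashing rounds are merged with log-sum-exp weights in the style of Reformer. These are simplifying assumptions, not necessarily the exact choices made in SMYRF.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def asymmetric_transform(queries, keys):
    # Hypothetical helper. F(q) = [q, 0, sqrt(M_Q^2 - |q|^2)] and
    # G(k) = [k, sqrt(M_K^2 - |k|^2), 0], so that
    # ||F(q) - G(k)||^2 = M_Q^2 + M_K^2 - 2 <q, k>:
    # a small distance means a large query-key inner product.
    m_q2 = max(dot(q, q) for q in queries)
    m_k2 = max(dot(k, k) for k in keys)
    fq = [list(q) + [0.0, math.sqrt(max(m_q2 - dot(q, q), 0.0))] for q in queries]
    gk = [list(k) + [math.sqrt(max(m_k2 - dot(k, k), 0.0)), 0.0] for k in keys]
    return fq, gk


def balanced_clusters(vectors, direction, cluster_size):
    # Sort by a random projection and cut into equal-size chunks, so every
    # cluster has exactly cluster_size members. Sorting costs O(N log N).
    order = sorted(range(len(vectors)), key=lambda i: dot(vectors[i], direction))
    return [order[s:s + cluster_size] for s in range(0, len(order), cluster_size)]


def attend(q, key_ids, keys, values):
    # Softmax attention of one query over a subset of keys (unscaled scores
    # for brevity). Returns the output vector and the log-sum-exp of the scores.
    key_ids = list(key_ids)
    scores = [dot(q, keys[j]) for j in key_ids]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    out = [sum(w * values[j][d] for w, j in zip(weights, key_ids)) / total
           for d in range(len(values[0]))]
    return out, top + math.log(total)


def dense_attention(queries, keys, values):
    return [attend(q, range(len(keys)), keys, values)[0] for q in queries]


def smyrf_like_attention(queries, keys, values, n_hashes, cluster_size, seed=0):
    n = len(queries)
    assert n_hashes >= 1 and n == len(keys) and n % cluster_size == 0
    rng = random.Random(seed)
    fq, gk = asymmetric_transform(queries, keys)
    rounds = [[] for _ in range(n)]  # per query: list of (output, lse)
    for _ in range(n_hashes):
        # One shared random direction per round hashes queries and keys alike.
        direction = [rng.gauss(0.0, 1.0) for _ in range(len(fq[0]))]
        q_clusters = balanced_clusters(fq, direction, cluster_size)
        k_clusters = balanced_clusters(gk, direction, cluster_size)
        # The i-th query cluster attends only to the i-th key cluster.
        for q_ids, k_ids in zip(q_clusters, k_clusters):
            for i in q_ids:
                rounds[i].append(attend(queries[i], k_ids, keys, values))
    outputs = []
    for per_query in rounds:
        # Merge hashing rounds with weights proportional to exp(lse)
        # (Reformer-style; a key found in several rounds is counted each time).
        top = max(lse for _, lse in per_query)
        ws = [math.exp(lse - top) for _, lse in per_query]
        z = sum(ws)
        dim = len(per_query[0][0])
        outputs.append([sum(w * out[d] for w, (out, _) in zip(ws, per_query)) / z
                        for d in range(dim)])
    return outputs


if __name__ == "__main__":
    rng = random.Random(0)
    n, d = 16, 8
    Q = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    K = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    V = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    exact = dense_attention(Q, K, V)
    approx = smyrf_like_attention(Q, K, V, n_hashes=2, cluster_size=4)
    err = max(abs(a - b) for r1, r2 in zip(exact, approx) for a, b in zip(r1, r2))
    print(f"max abs deviation from dense attention: {err:.3f}")
```

With n_hashes=1 and cluster_size equal to the sequence length, every query sees every key and the sketch reduces to dense attention. Smaller clusters mean each query scores only cluster_size keys per round, which is where the memory savings come from, at the cost of an approximation error.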

Authors: Giannis Daras, Nikita Kitaev, Augustus Odena, Alexandros G. Dimakis

Results

Memory-quality trade-off

GLUE benchmark

| Model | Avg. | # | C | CoLA | MNLI-m/mm | MRPC | QNLI | QQP | RTE | SST-2 | STS-B |
|---|---|---|---|---|---|---|---|---|---|---|---|
| BERT128 | 82.69 | 1 | 1 | 57.83 | 84.43/84.68 | 88.41 | 91.31 | 89.70 | 65.70 | 93.46 | 88.73 |
| SMYRF-BERT2x32 | 82.98 | 2 | 32 | 58.79 | 83.76/84.27 | 87.69 | 91.14 | 89.72 | 68.59 | 93.23 | 89.65 |
| SMYRF-BERT2x16 | 81.74 | 2 | 16 | 58.90 | 82.86/83.49 | 85.72 | 89.53 | 89.33 | 64.98 | 93.12 | 87.75 |
| BERT64 | 81.57 | 1 | 64 | 58.80 | 82.34/82.47 | 87.02 | 90.48 | 89.69 | 61.73 | 93.00 | 88.64 |
| BERT32 | 73.56 | 1 | 32 | 56.40 | 64.51/63.41 | 77.89 | 79.81 | 88.59 | 55.23 | 92.66 | 83.53 |

Interchangeability of SMYRF and dense attention

Results on the IMDB dataset. For models trained with SMYRF, switching to dense attention at inference consistently improves results, nearly matching the performance of dense attention.

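| Model | Memory | SMYRF inference | Accuracy |
|---|---|---|---|
| RoBERTa | 100% | No | 94.96% |
| SMYRF-RoBERTa | 50% | Yes | 93.72% |
| SMYRF-RoBERTa | 50% | No | 94.62% |
| BERT | 100% | No | 94.12% |
| SMYRF-BERT | 50% | Yes | 92.64% |
| SMYRF-BERT | 50% | No | 93.54% |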

SMYRF-BigGAN training on CelebA-HQ-128

Faces generated by a SMYRF-BigGAN trained at 128x128 resolution with attention at 128x128, using 50% of the memory of dense attention.

Results after 120k iterations:

| Model | Resolution | Attention | # | C | FID |
|---|---|---|---|---|---|
| BigGAN | 128x128 | 64x64 | 1 | 4096 | 26.06 |
| SMYRF-BigGAN | 128x128 | 128x128 | 4 | 2048 | 25.03 |

where # denotes the number of hashes and C the number of queries per cluster.
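As a rough sanity check on the memory figures, the sketch below counts the query-key scores that clustered attention materializes, # × N × C for N positions, against N^2 for dense attention. It is a back-of-the-envelope model under stated assumptions: it ignores keys, values and all other activations, assumes a GLUE sequence length of 128 (as the BERT128 label suggests), and takes N to be the number of pixel positions at the attention resolution. The function names are hypothetical.

```python
def score_entries(seq_len, n_hashes, cluster_size):
    # Each of the n_hashes rounds scores every query against cluster_size keys.
    return n_hashes * seq_len * cluster_size


def memory_fraction(seq_len, n_hashes, cluster_size):
    # Fraction of the seq_len x seq_len dense score matrix that is materialized.
    return score_entries(seq_len, n_hashes, cluster_size) / seq_len ** 2


configs = [
    ("SMYRF-BERT2x32, assumed N = 128", 128, 2, 32),
    ("SMYRF-BERT2x16, assumed N = 128", 128, 2, 16),
    ("BigGAN, attention at 64x64", 64 * 64, 1, 4096),
    ("SMYRF-BigGAN, attention at 128x128", 128 * 128, 4, 2048),
]
for name, n, h, c in configs:
    print(f"{name}: {memory_fraction(n, h, c):.0%} of dense score memory")
# prints 50%, 25%, 100% and 50% for the four rows
```

Under this model the 50% results match the memory figures stated in the abstract and in the figure caption, and the BigGAN row (# = 1, C = N) is simply dense attention.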

What's here

This repository hosts the code we used to run all the experiments in the paper.

For a deeper dive, look at the examples/ folder where we have code for pre-training SMYRF-BigGAN, sampling from a pre-trained BigGAN with SMYRF, finetuning state-of-the-art NLP models with SMYRF and a lot more.

Acknowledgments

We would like to wholeheartedly thank the TensorFlow Research Cloud (TFRC) program that gave us access to Cloud TPUs and GCP credits to train our models.

The code for the NLP experiments is exclusively based on the HuggingFace transformers library. We are very grateful to the authors of the library for their work.

The code for the CV experiments is based on the PyTorch implementation of BigGAN available at this URL. The code has been extended to support training on TPUs. Again, we want to thank the author for open-sourcing this implementation.

You might also like...
Code for ICE-BeeM paper - NeurIPS 2020

ICE-BeeM: Identifiable Conditional Energy-Based Deep Models Based on Nonlinear ICA This repository contains code to run and reproduce the experiments

Code for Discriminative Sounding Objects Localization (NeurIPS 2020)

Discriminative Sounding Objects Localization Code for our NeurIPS 2020 paper Discriminative Sounding Objects Localization via Self-supervised Audiovis

Advances in Neural Information Processing Systems (NeurIPS), 2020.

What is being transferred in transfer learning? This repo contains the code for the following paper: Behnam Neyshabur*, Hanie Sedghi*, Chiyuan Zhang*.

Neuron Merging: Compensating for Pruned Neurons (NeurIPS 2020)

Neuron Merging: Compensating for Pruned Neurons Pytorch implementation of Neuron Merging: Compensating for Pruned Neurons, accepted at 34th Conference

Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement (NeurIPS 2020)

MTTS-CAN: Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement Paper Xin Liu, Josh Fromm, Shwetak Patel, Daniel M

Defending graph neural networks against adversarial attacks (NeurIPS 2020)

GNNGuard: Defending Graph Neural Networks against Adversarial Attacks Authors: Xiang Zhang, Marinka Zitnik

Code for the Population-Based Bandits Algorithm, presented at NeurIPS 2020.

Population-Based Bandits (PB2) Code for the Population-Based Bandits (PB2) Algorithm, from the paper Provably Efficient Online Hyperparameter Optimiza

Code release for NeurIPS 2020 paper "Co-Tuning for Transfer Learning"

CoTuning Official implementation for NeurIPS 2020 paper Co-Tuning for Transfer Learning. [News] 2021/01/13 The COCO 70 dataset used in the paper is av

Discovering Interpretable GAN Controls [NeurIPS 2020]

GANSpace: Discovering Interpretable GAN Controls Figure 1: Sequences of image edits performed using control discovered with our method, applied to thr

Comments
  • Auto-regressive

    Hi Giannis!

    Thanks for the great paper! I am interested in your asymmetric LSH, as I think having separate query / key space (as opposed to shared QK as in Reformer) will bring performance improvements in LSH-based attention.

    I saw that you recommended that a previous user use this form of clustering for the auto-regressive case, and I wanted to ask whether you had considered the scenario where a bucket of queries does not get matched with any keys from the past at all. This was an issue I ran into when trying to make separate QK spaces work with the Routing Transformer, and I am wondering whether you had identified this problem and found a solution.

    Phil

    opened by lucidrains 2
  • Logging and scoring

    Currently, logging and scoring are disabled for TPU BigGAN for maximum efficiency. We could probably rewrite the logger and scorer to reduce their performance cost by converting most CPU materializations into XLA ops.

    bug example 
    opened by giannisdaras 0
  • Ema not working on TPU

    The exponential moving average (EMA) of G's weights is not working on TPUs. The problem is related to loading the state dict: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/utils.py#L614

    For now, we disable EMA.

    bug example 
    opened by giannisdaras 0
Releases (1.0)
Owner
Giannis Daras
Machine Learning Researcher. Ph.D. student, UT Austin.