You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling

Transformer-based models are widely used in natural language processing (NLP). Central to the transformer model is the self-attention mechanism, which captures the interactions of token pairs in the input sequences and depends quadratically on the sequence length. Training such models on longer sequences is expensive. In this paper, we show that a Bernoulli sampling attention mechanism based on Locality Sensitive Hashing (LSH) decreases the quadratic complexity of such models to linear. We bypass the quadratic cost by considering self-attention as a sum of individual tokens associated with Bernoulli random variables that can, in principle, be sampled at once by a single hash (although in practice, this number may be a small constant). This leads to an efficient sampling scheme to estimate self-attention which relies on specific modifications of LSH (to enable deployment on GPU architectures).
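
The estimator behind this abstract can be written out explicitly. What follows is a sketch of our reading of the paper, not code or notation taken from the repository. It assumes queries $q_i$ and keys $k_j$ are unit-normalized and that each hash $h_s$ ($s = 1, \dots, m$) is a $\tau$-bit sign-random-projection LSH. Under that hash, two vectors at angle $\theta = \arccos(q_i^\top k_j)$ agree on one bit with probability $1 - \theta/\pi$, so they collide on all $\tau$ bits with probability $(1 - \theta/\pi)^\tau$. The collision indicator is therefore a Bernoulli variable whose mean plays the role of the attention weight.

```latex
\begin{align*}
B^{(s)}_{ij} &= \mathbb{1}\big[h_s(q_i) = h_s(k_j)\big], \qquad
\mathbb{E}\big[B^{(s)}_{ij}\big] = \Big(1 - \tfrac{1}{\pi}\arccos\big(q_i^\top k_j\big)\Big)^{\tau} \\
\mathrm{YOSO}(Q, K, V)_i &= \sum_{j=1}^{n} \mathbb{E}\big[B^{(s)}_{ij}\big]\, v_j
  \approx \frac{1}{m} \sum_{s=1}^{m} \sum_{j=1}^{n} B^{(s)}_{ij}\, v_j
  = \frac{1}{m} \sum_{s=1}^{m} \sum_{j \,:\, h_s(k_j) = h_s(q_i)} v_j
\end{align*}
```

The last form is what makes the cost linear. For each hash, every $v_j$ is added into bucket $h_s(k_j)$ of a table with $2^\tau$ buckets, and each query then reads bucket $h_s(q_i)$. This takes $O(nm)$ bucket updates and lookups (plus the cost of hashing) instead of the $O(n^2)$ pairwise scores of softmax attention, and $m = 1$ is the "sample once" case from the title. The paper's GPU-oriented LSH modifications and any output normalization are not reflected in this sketch.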

Requirements

docker, nvidia-docker

Start Docker Container

Under the YOSO folder, run

docker run --ipc=host --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=<gpu_ids> -v "$PWD:/workspace" -it mlpen/transformers:4

For NVIDIA RTX 30-series GPUs, run

docker run --ipc=host --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=<gpu_ids> -v "$PWD:/workspace" -it mlpen/transformers:5

Then, the YOSO folder is mapped to /workspace in the container.
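
The two docker commands differ only in the image tag. The following bash sketch picks the tag from the name of the first GPU reported by nvidia-smi. The matching rule (any name containing "RTX 30xx" gets tag 5) and the default device ID 0 are assumptions, so adjust GPU_IDS or the pattern for your machine.

```shell
#!/usr/bin/env bash
# Sketch only: choose the mlpen/transformers image tag from the GPU name.
# Assumption: names containing "RTX 30xx" get tag 5, everything else tag 4.

pick_image_tag() {
  # $1: a GPU name, e.g. as printed by nvidia-smi
  case "$1" in
    *"RTX 30"[0-9][0-9]*) echo 5 ;;
    *) echo 4 ;;
  esac
}

run_container() {
  # GPU_IDS is a hypothetical knob; it defaults to the first GPU only.
  local gpu_ids="${GPU_IDS:-0}"
  local gpu_name tag
  # Name of the first visible GPU (nvidia-smi ships with the NVIDIA driver).
  gpu_name="$(nvidia-smi --query-gpu=name --format=csv,noheader | head -n 1)"
  tag="$(pick_image_tag "$gpu_name")"
  # Same flags as the commands above, with the tag filled in.
  docker run --ipc=host --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES="$gpu_ids" \
    -v "$PWD:/workspace" -it "mlpen/transformers:$tag"
}
```

Source the file and call run_container from the YOSO folder; nothing runs at load time, so the functions can also be reused in other scripts.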

BERT

Datasets

To be updated

Pre-training

To pre-train a specific configuration, create a folder YOSO/BERT/models/<model> (for example, bert-small) and write YOSO/BERT/models/<model>/config.json to specify the model and training configuration. Then, under the YOSO/BERT folder, run

python3 run_pretrain.py --model <model>
The command will create a YOSO/BERT/models/<model>/model folder holding all checkpoints and the log file.
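
A small bash wrapper for this step, as a sketch: pretrain_bert is a hypothetical helper that refuses to start when models/<model>/config.json is missing and, with DRY_RUN=1, prints the command instead of running it. It does not write config.json, since the keys the training code expects are not documented here.

```shell
#!/usr/bin/env bash
# Sketch only: pretrain_bert is a hypothetical helper, run from YOSO/BERT.

pretrain_bert() {
  local model="$1"
  local dir="models/$model"
  # The README expects models/<model>/config.json to exist before training.
  if [ ! -f "$dir/config.json" ]; then
    echo "error: $dir/config.json not found" >&2
    return 1
  fi
  if [ "${DRY_RUN:-0}" = "1" ]; then
    # Print the command instead of running it.
    echo "python3 run_pretrain.py --model $model"
    return 0
  fi
  python3 run_pretrain.py --model "$model" || return 1
  # Checkpoints and the log file end up in models/<model>/model.
  ls -1 "$dir/model"
}
```

For example, DRY_RUN=1 pretrain_bert bert-small prints python3 run_pretrain.py --model bert-small, and pretrain_bert bert-small starts the run.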

Pre-training from a Different Model's Checkpoint

Copy a checkpoint (either a .model or a .cp file) from the YOSO/BERT/models/<source model>/model folder to the YOSO/BERT/models/<model> folder, and add the key-value pair "from_cp": "<checkpoint>" to YOSO/BERT/models/<model>/config.json. An example is shown in YOSO/BERT/models/bert-small-4096/config.json. This procedure also works for extending the maximum sequence length of a model (for example, using bert-small pre-trained weights to initialize bert-small-4096).
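
The checkpoint hand-off can be scripted. This bash sketch uses a hypothetical helper, init_from_checkpoint, that copies a checkpoint from one model's model folder into another model's folder and sets "from_cp" in the target config.json using Python's json module. It assumes "from_cp" takes the checkpoint's file name; compare against YOSO/BERT/models/bert-small-4096/config.json before relying on it. Rewriting the file also normalizes its indentation.

```shell
#!/usr/bin/env bash
# Sketch only: init_from_checkpoint is hypothetical; run it from YOSO/BERT.
# Assumption: "from_cp" holds the checkpoint's file name.

init_from_checkpoint() {
  local src_model="$1" dst_model="$2" ckpt="$3"
  local dst_dir="models/$dst_model"
  # Copy the checkpoint next to the target model's config.json.
  cp -- "models/$src_model/model/$ckpt" "$dst_dir/" || return 1
  # Add or overwrite the "from_cp" key, keeping all other keys.
  python3 - "$dst_dir/config.json" "$ckpt" <<'EOF'
import json
import sys

path, ckpt = sys.argv[1], sys.argv[2]
with open(path) as f:
    config = json.load(f)
config["from_cp"] = ckpt
with open(path, "w") as f:
    json.dump(config, f, indent=4)
EOF
}
```

For the sequence-length case described above, the call would look like init_from_checkpoint bert-small bert-small-4096 <checkpoint>, with the checkpoint file name taken from models/bert-small/model.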

GLUE Fine-tuning

Under the YOSO/BERT folder, run

python3 run_glue.py --model <model> --batch_size <batch_size> --lr <lr> --task <task> --checkpoint <checkpoint>

For example,

python3 run_glue.py --model bert-small --batch_size 32 --lr 3e-5 --task MRPC --checkpoint cp-0249.model

The command will create a log file in YOSO/BERT/models/<model>/model.
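
To fine-tune the same checkpoint on several GLUE tasks, the command can be wrapped in a loop. In this bash sketch, glue_sweep is a hypothetical helper; batch size 32 and learning rate 3e-5 come from the example above and can be overridden with BATCH_SIZE and LR. Only MRPC appears in this README, so any other task name is an assumption about what run_glue.py accepts.

```shell
#!/usr/bin/env bash
# Sketch only: glue_sweep is hypothetical; run it from YOSO/BERT.
# Usage: glue_sweep <model> <checkpoint> <task>...

glue_sweep() {
  local model="$1" checkpoint="$2"
  shift 2
  local batch_size="${BATCH_SIZE:-32}" lr="${LR:-3e-5}"
  local task
  local -a cmd
  for task in "$@"; do
    cmd=(python3 run_glue.py --model "$model" --batch_size "$batch_size"
         --lr "$lr" --task "$task" --checkpoint "$checkpoint")
    if [ "${DRY_RUN:-0}" = "1" ]; then
      # Print the command instead of running it.
      echo "${cmd[*]}"
    else
      # Each run writes its log under models/<model>/model.
      "${cmd[@]}" || return 1
    fi
  done
}
```

Building the command as an array keeps every argument intact even if a value contains spaces, and DRY_RUN=1 lets you review the full list before launching anything.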

Long Range Arena Benchmark

Datasets

To be updated

Run Evaluations

To evaluate a specific model on a task in the LRA benchmark:

  • Create a folder YOSO/LRA/models/<model> (for example, softmax)
  • Write YOSO/LRA/models/<model>/config.json to specify the model and training configuration

Under the YOSO/LRA folder, run

python3 run_task.py --model <model> --task <task>

For example, run

python3 run_task.py --model softmax --task listops

The command will create a YOSO/LRA/models/<model>/model folder holding the best validation checkpoint and the log file. After completion, the test set accuracy can be found in the last line of the log file.
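
The LRA steps can be wrapped the same way. In this bash sketch, lra_run (hypothetical) runs one model on the listed tasks, and lra_report prints the last line of a log file, which is where the test accuracy is reported. The log file's name is not given here, so lra_report takes its path as an argument; task names other than listops are assumptions.

```shell
#!/usr/bin/env bash
# Sketch only: lra_run and lra_report are hypothetical; run from YOSO/LRA.

lra_report() {
  # $1: path to a finished run's log file; the last line holds test accuracy.
  tail -n 1 -- "$1"
}

lra_run() {
  # Usage: lra_run <model> <task>...
  local model="$1"
  shift
  local task
  for task in "$@"; do
    if [ "${DRY_RUN:-0}" = "1" ]; then
      # Print the command instead of running it.
      echo "python3 run_task.py --model $model --task $task"
    else
      python3 run_task.py --model "$model" --task "$task" || return 1
    fi
  done
}
```

For example, lra_run softmax listops followed by lra_report on the log file in models/softmax/model reproduces the manual steps above.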

RoBERTa

Datasets

To be updated

Pre-training

To start pre-training a specific configuration:

  • Create a folder YOSO/RoBERTa/models/<model> (for example, bert-small)
  • Write YOSO/RoBERTa/models/<model>/config.json to specify the model and training configuration

Under the YOSO/RoBERTa folder, run

python3 run_pretrain.py --model <model>

For example, run

python3 run_pretrain.py --model bert-small

The command will create a YOSO/RoBERTa/models/<model>/model folder holding all checkpoints and the log file.

GLUE Fine-tuning

To fine-tune a model on GLUE tasks:

Under the YOSO/RoBERTa folder, run

python3 run_glue.py --model <model> --batch_size <batch_size> --lr <lr> --task <task> --checkpoint <checkpoint>

For example,

python3 run_glue.py --model bert-small --batch_size 32 --lr 3e-5 --task MRPC --checkpoint 249

The command will create a log file in YOSO/RoBERTa/models/<model>/model.
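
For RoBERTa, a learning-rate sweep over one task might look like the bash sketch below. roberta_lr_sweep is a hypothetical helper, and the default grid (1e-5, 2e-5, 3e-5) is an illustrative choice that can be overridden via LRS. The RoBERTa example passes --checkpoint as a bare number (249), whereas the BERT example passes a file name (cp-0249.model); the sketch forwards whatever it is given.

```shell
#!/usr/bin/env bash
# Sketch only: roberta_lr_sweep is hypothetical; run it from YOSO/RoBERTa.
# Usage: roberta_lr_sweep <model> <task> <checkpoint>

roberta_lr_sweep() {
  local model="$1" task="$2" checkpoint="$3"
  local batch_size="${BATCH_SIZE:-32}"
  local lr
  # LRS is split on whitespace on purpose, one run per value.
  for lr in ${LRS:-1e-5 2e-5 3e-5}; do
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo "python3 run_glue.py --model $model --batch_size $batch_size --lr $lr --task $task --checkpoint $checkpoint"
    else
      python3 run_glue.py --model "$model" --batch_size "$batch_size" \
        --lr "$lr" --task "$task" --checkpoint "$checkpoint" || return 1
    fi
  done
}
```

Each run writes its own log in the same model folder, so compare the logs after the sweep finishes.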

Citation

@inproceedings{zeng2021yoso,
  title={You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling},
  author={Zeng, Zhanpeng and Xiong, Yunyang and Ravi, Sathya N. and Acharya, Shailesh and Fung, Glenn and Singh, Vikas},
  booktitle={Proceedings of the International Conference on Machine Learning},
  year={2021}
}
Owner
Zhanpeng Zeng