Code for our ALiBi method for transformer language models.

Overview

Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation

This repository contains the code and models for our paper Train Short, Test Long. This file explains how to run our experiments on the WikiText-103 dataset. Read the paper here.

Attention with Linear Biases (ALiBi) is very simple! Instead of adding position embeddings at the bottom of the transformer stack (which we don't) we add a linear bias to each attention score, as depicted in the figure above. The 'm' hyperparam is head-specific and is not learned- it is set at the beginning of training. We have a function that automatically generates these m values given the number of heads in the model.

ALiBi allows the model to be trained on, for example, 1024 tokens, and then do inference on 2048 (or much more) tokens without any finetuning. It's also able to improve performance, even when not extrapolating, in lower resource language modeling settings.

The implementation is very simple.

  1. Remove the position embeddings from the model: https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py#L941
  2. Set up the relative bias matrix, here: https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py#L742
  3. Add the bias matrix to the mask, which is then added in each attention score computation: https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py#L1011
  4. (This might not be necessary in other frameworks.) Move the mask computation to before the layer loop, to make the transformer a tiny bit faster: https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py#L949

Thats it!

Citation:

@misc{press2021train,
      title={Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation}, 
      author={Ofir Press and Noah A. Smith and Mike Lewis},
      year={2021},
      eprint={2108.12409},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

WikiText-103

Requirements and Installation

This repository is a fork of the Fairseq repository and so has the same requirements.

Once you've installed the dependencies, you can install this repository by running:

pip install --editable .

Preparing the data

To download and preprocess the data, run:

cd examples/language_model/
bash prepare-wikitext-103.sh
cd ../..


TEXT=examples/language_model/wikitext-103
python preprocess.py \
    --only-source \
    --trainpref $TEXT/wiki.train.tokens \
    --validpref $TEXT/wiki.valid.tokens \
    --testpref $TEXT/wiki.test.tokens \
    --destdir data-bin/wikitext-103 \
    --workers 20

Training and Inference

To train a language model with attention with linear baises (ALiBi), on input sequences with 512 tokens, run:

python train.py --task language_modeling     data-bin/wikitext-103     --save-dir wt103/  --arch transformer_lm_wiki103     --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75     --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1     --criterion adaptive_loss --seed 1 --fp16     --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --no-epoch-checkpoints --tokens-per-sample 512 --max-tokens 9216 --update-freq 1  

For input sequences larger than 512 (and up to 2048) tokens, just change the --tokens-per-sample.

To train the model with inputs of 3072 tokens, the --update-freq parameter must be changed to 3 and the --max-tokens parameter must be reduced to 3072.

Saved Checkpoints

If you'd like to download our trained models on WikiText-103, they are available here:

Input Length Link
64 https://dl.fbaipublicfiles.com/train_short_test_long/wt103/alibi_wt103_L64.pt
128 https://dl.fbaipublicfiles.com/train_short_test_long/wt103/alibi_wt103_L128.pt
256 https://dl.fbaipublicfiles.com/train_short_test_long/wt103/alibi_wt103_L256.pt
512 https://dl.fbaipublicfiles.com/train_short_test_long/wt103/alibi_wt103_L512.pt
1024 https://dl.fbaipublicfiles.com/train_short_test_long/wt103/alibi_wt103_L1024.pt
1536 https://dl.fbaipublicfiles.com/train_short_test_long/wt103/alibi_wt103_L1536.pt
2048 https://dl.fbaipublicfiles.com/train_short_test_long/wt103/alibi_wt103_L2048.pt
3072 https://dl.fbaipublicfiles.com/train_short_test_long/wt103/alibi_wt103_L3072.pt

Rename the file you downloaded to checkpoint_best.pt if you'd like to follow the directions below.

Inference

For nonoverlapping evaluation of the validation set, run:

l=1024; fairseq-eval-lm data-bin/wikitext-103/     --path wt103/checkpoint_best.pt  --sample-break-mode none --gen-subset valid   --max-sentences 1 --model-overrides "{'max_tokens':$l, 'tokens_per_sample':$l, 'max_target_positions':$l}"  --tokens-per-sample $l --max-tokens $l  --max-target-positions $l  --context-window 0

where l is set to the length of input subsequences during validation (l=1024 in the above example).

Owner
Ofir Press
PhD student @uwnlp
Ofir Press
A PyTorch implementation of the Relational Graph Convolutional Network (RGCN).

Torch-RGCN Torch-RGCN is a PyTorch implementation of the RGCN, originally proposed by Schlichtkrull et al. in Modeling Relational Data with Graph Conv

Thiviyan Singam 66 Nov 30, 2022
The code for the NeurIPS 2021 paper "A Unified View of cGANs with and without Classifiers".

Energy-based Conditional Generative Adversarial Network (ECGAN) This is the code for the NeurIPS 2021 paper "A Unified View of cGANs with and without

sianchen 22 May 28, 2022
Codes for [NeurIPS'21] You are caught stealing my winning lottery ticket! Making a lottery ticket claim its ownership.

You are caught stealing my winning lottery ticket! Making a lottery ticket claim its ownership Codes for [NeurIPS'21] You are caught stealing my winni

VITA 8 Nov 01, 2022
[ICCV2021] Official code for "Channel-wise Topology Refinement Graph Convolution for Skeleton-Based Action Recognition"

CTR-GCN This repo is the official implementation for Channel-wise Topology Refinement Graph Convolution for Skeleton-Based Action Recognition. The pap

Yuxin Chen 148 Dec 16, 2022
The Self-Supervised Learner can be used to train a classifier with fewer labeled examples needed using self-supervised learning.

Published by SpaceML • About SpaceML • Quick Colab Example Self-Supervised Learner The Self-Supervised Learner can be used to train a classifier with

SpaceML 92 Nov 30, 2022
⚓ Eurybia monitor model drift over time and securize model deployment with data validation

View Demo · Documentation · Medium article 🔍 Overview Eurybia is a Python library which aims to help in : Detecting data drift and model drift Valida

MAIF 172 Dec 27, 2022
Diabet Feature Engineering - Predict whether people have diabetes when their characteristics are specified

Diabet Feature Engineering - Predict whether people have diabetes when their characteristics are specified

Şebnem 6 Jan 18, 2022
Official Pytorch implementation of MixMo framework

MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks Official PyTorch implementation of the MixMo framework | paper | docs Alexandr

79 Nov 07, 2022
PyTorch implementation of "Representing Shape Collections with Alignment-Aware Linear Models" paper.

deep-linear-shapes PyTorch implementation of "Representing Shape Collections with Alignment-Aware Linear Models" paper. If you find this code useful i

Romain Loiseau 27 Sep 24, 2022
Learning to trade under the reinforcement learning framework

Trading Using Q-Learning In this project, I will present an adaptive learning model to trade a single stock under the reinforcement learning framework

Uirá Caiado 470 Nov 28, 2022
Adaptable tools to make reinforcement learning and evolutionary computation algorithms.

Pearl The Parallel Evolutionary and Reinforcement Learning Library (Pearl) is a pytorch based package with the goal of being excellent for rapid proto

38 Jan 01, 2023
Learning Saliency Propagation for Semi-supervised Instance Segmentation

Learning Saliency Propagation for Semi-supervised Instance Segmentation PyTorch Implementation This repository contains: the PyTorch implementation of

Berkeley DeepDrive 68 Oct 18, 2022
Learning Correspondence from the Cycle-consistency of Time (CVPR 2019)

TimeCycle Code for Learning Correspondence from the Cycle-consistency of Time (CVPR 2019, Oral). The code is developed based on the PyTorch framework,

Xiaolong Wang 706 Nov 29, 2022
Code for PackNet: Adding Multiple Tasks to a Single Network by Iterative Pruning

PackNet: https://arxiv.org/abs/1711.05769 Pretrained models are available here: https://uofi.box.com/s/zap2p03tnst9dfisad4u0sfupc0y1fxt Datasets in Py

Arun Mallya 216 Jan 05, 2023
A collection of random and hastily hacked together scripts for investigating EU-DCC

A collection of random and hastily hacked together scripts for investigating EU-DCC

Ryan Barrett 8 Mar 01, 2022
Multi-Horizon-Forecasting-for-Limit-Order-Books

Multi-Horizon-Forecasting-for-Limit-Order-Books This jupyter notebook is used to demonstrate our work, Multi-Horizon Forecasting for Limit Order Books

Zihao Zhang 116 Dec 23, 2022
Evaluating deep transfer learning for whole-brain cognitive decoding

Evaluating deep transfer learning for whole-brain cognitive decoding This README file contains the following sections: Project description Repository

Armin Thomas 5 Oct 31, 2022
This is the repo of the manuscript "Dual-branch Attention-In-Attention Transformer for speech enhancement"

DB-AIAT: A Dual-branch attention-in-attention transformer for single-channel SE

Guochen Yu 68 Dec 16, 2022
This is the code of paper ``Contrastive Coding for Active Learning under Class Distribution Mismatch'' with python.

Contrastive Coding for Active Learning under Class Distribution Mismatch Official PyTorch implementation of ["Contrastive Coding for Active Learning u

21 Dec 22, 2022
YuNetのPythonでのONNX、TensorFlow-Lite推論サンプル

YuNet-ONNX-TFLite-Sample YuNetのPythonでのONNX、TensorFlow-Lite推論サンプルです。 TensorFlow-LiteモデルはPINTO0309/PINTO_model_zoo/144_YuNetのものを使用しています。 Requirement Op

KazuhitoTakahashi 8 Nov 17, 2021