Code for the Shortformer model, from the paper by Ofir Press, Noah A. Smith and Mike Lewis.

Overview

Shortformer

This repository contains the code and the final checkpoint of the Shortformer model. This file explains how to run our experiments on the WikiText-103 dataset. Read the full paper here.

The Shortformer is a combination of two methods:

  1. Staged Training: We first train the model on short input subsequences and then train it on longer ones. This improves both training speed and evaluation perplexity.
  2. Position-Infused Attention + Caching: We cache previously computed subsequence representations and attend to them using Position-Infused Attention (PIA). PIA modifies the model so that position embeddings are not added to the word embeddings at the bottom of the network; instead, they are added to the keys and queries in the attention sublayer (but not to the values). We show that PIA + caching vastly speeds up generation and also improves perplexity.
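The pure-Python sketch below illustrates the idea behind PIA + caching for a single causal attention sublayer. It is not the code in this repository: the helper names (pia_attention, matvec, and so on) are hypothetical, there is one head and no layer norm or feed-forward sublayer, and counting positions from the first cached token is an assumption of this sketch. What it does show is that position embeddings enter only the queries and keys, so the cached states carry no absolute position and can be reused when the model moves on to the next subsequence.

```python
import math
import random

# Hypothetical, single-head sketch of Position-Infused Attention (PIA) with
# caching; not the code in this repository. Vectors are plain Python lists.


def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def add(a, b):
    return [ai + bi for ai, bi in zip(a, b)]


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def pia_attention(cache, current, pos_emb, Wq, Wk, Wv):
    """Causal attention from `current` over `cache + current` with PIA.

    cache:   this layer's inputs for the previous subsequence. With PIA they
             contain no position information, so they can be reused as-is.
    pos_emb: position embeddings, indexed from the first cached token
             (an assumption of this sketch).
    """
    states = cache + current
    offset = len(cache)
    # Position embeddings are added to keys (and to queries below), not values.
    keys = [matvec(Wk, add(h, pos_emb[j])) for j, h in enumerate(states)]
    values = [matvec(Wv, h) for h in states]
    outputs = []
    for i, x in enumerate(current):
        pos = offset + i
        q = matvec(Wq, add(x, pos_emb[pos]))
        scale = 1.0 / math.sqrt(len(q))
        # Causal mask: only the cache and tokens up to `pos` are visible.
        scores = [scale * sum(a * b for a, b in zip(q, keys[j]))
                  for j in range(pos + 1)]
        weights = softmax(scores)
        outputs.append([sum(w * values[j][d] for j, w in enumerate(weights))
                        for d in range(len(values[0]))])
    return outputs


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d, L = 4, 3  # toy model dimension and subsequence length
Wq, Wk, Wv = (rand_matrix(d, d, rng) for _ in range(3))
pos_emb = rand_matrix(2 * L, d, rng)  # positions for cache + current
seg1 = rand_matrix(L, d, rng)  # word embeddings only: no positions added
seg2 = rand_matrix(L, d, rng)

out1 = pia_attention([], seg1, pos_emb, Wq, Wk, Wv)
# Processing seg2 reuses seg1's position-free states as the cache instead of
# recomputing them; out1 would likewise serve as the cache for the layer above.
out2 = pia_attention(seg1, seg2, pos_emb, Wq, Wk, Wv)
```

For a single layer, attending to the cache gives exactly the same outputs as attending over seg1 + seg2 in one pass; the saving comes from not recomputing seg1's states, which is only valid because, with PIA, those states do not encode the absolute positions they were computed at.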

Staged training requires no modification to the original code. To see how we implemented Position-Infused Attention and caching, click here. Implementing PIA and caching is very easy, and we've provided detailed comments in the code to explain how we did it.

If you use this code or results from our paper, please cite:

@misc{press2020shortformer,
      title={Shortformer: Better Language Modeling using Shorter Inputs}, 
      author={Ofir Press and Noah A. Smith and Mike Lewis},
      year={2020},
      eprint={2012.15832},
}

Requirements and Installation

This repository is a fork of the Fairseq repository and so has the same requirements.

Once you've installed the dependencies, you can install this repository by running:

pip install --editable .

Preparing the data

To download and preprocess the data, run:

cd examples/language_model/
bash prepare-wikitext-103.sh
cd ../..


TEXT=examples/language_model/wikitext-103
python preprocess.py \
    --only-source \
    --trainpref $TEXT/wiki.train.tokens \
    --validpref $TEXT/wiki.valid.tokens \
    --testpref $TEXT/wiki.test.tokens \
    --destdir data-bin/wikitext-103 \
    --workers 20

Train/Inference for the different models

Shortformer

Our Shortformer model takes the baseline and adds caching, Position-Infused Attention, and Staged Training.

To train the first stage:

python train.py --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir checkpoints128e100/ \
    --arch transformer_lm_wiki103 \
    --max-update 140100 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 128 --max-tokens-valid 128 --tokens-from-prev 128 --curriculum 1000 --required-batch-size-multiple 1 --save-interval 100

If your GPUs don't have enough memory to run that command, set --update-freq to 2 and --max-tokens to 4608, or, for even lower memory use, set --update-freq to 3 and --max-tokens to 3072. This splits each batch into 2 or 3 parts that are processed one after another (instead of in parallel) and accumulates their gradients before each update, so it uses less memory but runs slower. The effective batch size (--max-tokens × --update-freq = 9,216 tokens per GPU) stays the same.
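A toy check of why these settings are interchangeable: below, a hypothetical one-parameter model with a summed squared-error loss stands in for the language model, and we assume (as in fairseq's trainer) that the loss is summed over tokens and normalized once per update. Under that assumption, accumulating gradients over `update_freq` chunks gives the same gradient as one large batch. This is an illustration, not fairseq code.

```python
# Hypothetical toy check (not fairseq code): accumulating gradients over
# `update_freq` chunks, as --update-freq does, matches one large batch when
# the loss is summed over tokens and normalized once per update.


def grad(w, batch):
    """Gradient of sum((w * x - y) ** 2) over the batch w.r.t. w."""
    return sum(2 * (w * x - y) * x for x, y in batch)


def accumulated_grad(w, batch, update_freq):
    # Assumes len(batch) is divisible by update_freq.
    chunk = len(batch) // update_freq
    total = 0.0
    for k in range(update_freq):
        # Each chunk is processed separately, so only it has to fit in memory.
        total += grad(w, batch[k * chunk:(k + 1) * chunk])
    return total / len(batch)  # normalize once, after accumulation


batch = [(float(i), 2.0 * i + 1.0) for i in range(12)]  # toy (x, y) pairs
w = 0.5
full_batch_grad = grad(w, batch) / len(batch)

# The memory-saving settings keep --max-tokens x --update-freq fixed.
settings = [(9216, 1), (4608, 2), (3072, 3)]
effective_tokens = {max_tokens * update_freq for max_tokens, update_freq in settings}
print(effective_tokens)  # {9216}
for _, update_freq in settings:
    print(update_freq, accumulated_grad(w, batch, update_freq), full_batch_grad)
```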

After that, to train the model with the second (and final) stage:

python train.py --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir shortformer/ --restore-file checkpoints128e100/checkpoint100.pt \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 512 --max-tokens-valid 512 --tokens-from-prev 512 --curriculum 1000 --required-batch-size-multiple 1 --no-epoch-checkpoints
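The two stages differ mainly in --tokens-per-sample and --max-update. Assuming that, as in standard fairseq, --restore-file resumes the update counter from the checkpoint, --max-update 286000 is a total, so the second stage runs for the remaining 145,900 updates. The hypothetical sketch below (not code from this repository) encodes that schedule and shows that both stages keep the per-batch token budget (--max-tokens 9216) fixed: only the subsequence length changes, so a stage-1 batch holds 72 subsequences of 128 tokens and a stage-2 batch holds 18 of 512.

```python
# Hypothetical sketch of the two-stage schedule set by the flags above;
# not code from this repository.
STAGES = [
    # (last update of the stage, --tokens-per-sample)
    (140_100, 128),  # stage 1: --max-update 140100
    (286_000, 512),  # stage 2: --max-update 286000 (assumed to count from update 0)
]
MAX_TOKENS = 9216  # --max-tokens: tokens per batch (per GPU)


def tokens_per_sample_at(update):
    """Subsequence length used at a given 1-based update number."""
    for last_update, length in STAGES:
        if update <= last_update:
            return length
    raise ValueError("update is past the final --max-update")


def split_into_blocks(tokens, length):
    """Like --sample-break-mode none: cut the token stream into fixed-size blocks."""
    return [tokens[i:i + length] for i in range(0, len(tokens), length)]


for last_update, length in STAGES:
    print(f"up to update {last_update}: {length}-token subsequences, "
          f"{MAX_TOKENS // length} per batch")
```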

Again, you can use the --update-freq/--max-tokens settings described above if you run out of memory.

Saved Checkpoint

If you'd like to download the Shortformer instead of training it, it is available here. Rename the downloaded file to checkpoint_best.pt and place it in shortformer/ if you'd like to follow the directions below.

Inference

For nonoverlapping evaluation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103 \
    --path shortformer/checkpoint_best.pt \
    --sample-break-mode none --gen-subset valid \
    --max-sentences 1

For token-by-token generation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103 \
    --path shortformer/checkpoint_best.pt \
    --sample-break-mode none --gen-subset valid \
    --max-sentences 1 --sliding-inf 1 --context-window 511 --max-tokens 512

(Note that --context-window is a fairseq option and doesn't have the exact meaning that the term "context window" has in our paper.)
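Our reading of fairseq-eval-lm (an assumption; check eval_lm.py in the fairseq version you use) is that each evaluation window is as long as the checkpoint's --tokens-per-sample, that the first --context-window tokens of a window serve only as context, and that the window advances by the difference. The hypothetical helper below sketches that bookkeeping: with 512-token windows, --context-window 511 scores one token per forward pass, which is the token-by-token setting above, while --context-window 0 gives nonoverlapping evaluation.

```python
# Hypothetical sketch of how we understand fairseq-eval-lm's --context-window;
# not fairseq code.


def sliding_windows(num_tokens, tokens_per_sample, context_window):
    """Return (start, end, first_scored) for each evaluation window.

    The model sees tokens[start:end] but only tokens[first_scored:end] are
    scored; earlier tokens are context. Windows advance by
    tokens_per_sample - context_window tokens.
    """
    stride = tokens_per_sample - context_window
    if stride <= 0:
        raise ValueError("context_window must be smaller than tokens_per_sample")
    windows = []
    for first_scored in range(0, num_tokens, stride):
        start = max(0, first_scored - context_window)
        end = min(num_tokens, first_scored + stride)
        windows.append((start, end, first_scored))
    return windows


def positions_processed(windows):
    """Total token positions fed through the model over all windows."""
    return sum(end - start for start, end, _ in windows)


# Nonoverlapping evaluation: no context carried over between windows.
nonoverlapping = sliding_windows(10_000, 512, 0)
# Token-by-token evaluation (--context-window 511 with 512-token windows):
# one forward pass per scored token.
token_by_token = sliding_windows(10_000, 512, 511)
print(len(nonoverlapping), positions_processed(nonoverlapping))  # 20 10000
print(len(token_by_token), positions_processed(token_by_token))
```

The same function with 3,072-token windows and a context of 2,560 tokens advances 512 tokens at a time, the setting used for the baseline's sliding-window evaluation. Counting processed positions makes the cost of small strides visible: token-by-token evaluation here feeds roughly 500 times as many positions through the model as nonoverlapping evaluation.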

Shortformer (without Staged Training)

Staged training improves the perplexity of the model and makes training faster, so there's no reason not to use it. If you would still like to train the Shortformer without it, run:

python train.py --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir shortformer-no-st/ \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 512 --max-tokens-valid 512 --tokens-from-prev 512 --curriculum 1000 --required-batch-size-multiple 1 --no-epoch-checkpoints

For inference, use the same commands as the ones for the Shortformer (above).

Baseline with Staged Training

Our Shortformer model is fast both to train and at token-by-token generation, but if speed is not a concern, slightly better perplexity can be achieved by applying only Staged Training to the Baevski & Auli baseline LM. This model is very slow but achieves the best perplexity.

To train the first stage, download the unmodified fairseq repository and then run:

python train.py --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir checkpoints-st-128e50/ \
    --arch transformer_lm_wiki103 \
    --max-update 70050 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 128 --required-batch-size-multiple 1 --save-interval 50

After that, to train the model with the second (and final) stage:

python train.py --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir st/ --restore-file checkpoints-st-128e50/checkpoint50.pt \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 3072 --no-epoch-checkpoints

Inference

For nonoverlapping evaluation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103 \
    --path st/checkpoint_best.pt \
    --sample-break-mode none --gen-subset valid \
    --max-sentences 1

For sliding-window evaluation of the validation set, with 2,560 tokens of context per window (each 3,072-token window scores only its last 512 tokens, so the stride is 512), run:

fairseq-eval-lm data-bin/wikitext-103 \
    --path st/checkpoint_best.pt \
    --sample-break-mode none --gen-subset valid \
    --max-sentences 1 --context-window 2560

Baseline - Baevski & Auli

To train the baseline, download the unmodified fairseq repository and then run:

python train.py --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir baseline/ \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 3072 --no-epoch-checkpoints

Inference

Use the same commands as in the 'Baseline with Staged Training' inference subsection.
