Code for the Shortformer model, from the paper by Ofir Press, Noah A. Smith and Mike Lewis.

Overview

Shortformer

This repository contains the code and the final checkpoint of the Shortformer model. This file explains how to run our experiments on the WikiText-103 dataset. Read the full paper here.

The Shortformer is a combination of two methods:

  1. Staged Training: We first train the model on short input subsequences and then train it on longer ones. This improves both train speed and evaluation perplexity.
  2. Position-Infused Attention + Caching: We cache previously computed subsequence representations and attend to them using Position-Infused Attention. Position-Infused Attention modifies the model so that position embeddings are not added to the word embeddings at the bottom of the network, but instead, they are added to the keys and queries in the attention sublayer (but not to the values). We show that PIA + caching vastly speeds up generation and also improves perplexity.

Staged training requires no modification to the original code. To see how we implemented the Position-Infused Attention and caching, click here. Implementing PIA and caching is very easy, and we've provided detailed comments in the code to explain what how we did it.

If you use this code or results from our paper, please cite:

@misc{press2020shortformer,
      title={Shortformer: Better Language Modeling using Shorter Inputs}, 
      author={Ofir Press and Noah A. Smith and Mike Lewis},
      year={2020},
      eprint={2012.15832},
}

Requirements and Installation

This repository is a fork of the Fairseq repository and so has the same requirements.

Once you've installed the dependencies, you can install this repository by running:

pip install --editable .

Preparing the data

To download and preprocess the data, run:

cd examples/language_model/
bash prepare-wikitext-103.sh
cd ../..


TEXT=examples/language_model/wikitext-103
python preprocess.py \
    --only-source \
    --trainpref $TEXT/wiki.train.tokens \
    --validpref $TEXT/wiki.valid.tokens \
    --testpref $TEXT/wiki.test.tokens \
    --destdir data-bin/wikitext-103 \
    --workers 20

Train/Inference for the different models

Shortformer

Our Shortformer model takes the baseline and adds caching, Position-Infused Attention, and Staged Training.

To train the first stage:

python train.py --task language_modeling     data-bin/wikitext-103     --save-dir checkpoints128e100/     --arch transformer_lm_wiki103     --max-update 140100 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75     --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1     --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16     --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 128 --max-tokens-valid 128 --tokens-from-prev 128 --curriculum 1000 --required-batch-size-multiple 1 --save-interval 100

If your GPUs don't have enough memory to execute that command, you can set --update-freq to 2 and --max-tokens to 4608, or set --update-freq to 3 and --max-tokens to 3072 for running the model with even lower memory constraints. This chunks the batch into 2 or 3 different parts and computes each part seperately (instead of in parallel), so it uses less memory but runs slower.

After that, to train the model with the second (and final) stage:

python train.py --task language_modeling     data-bin/wikitext-103     --save-dir shortformer/ --restore-file checkpoints128e100/checkpoint100.pt     --arch transformer_lm_wiki103     --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75     --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1     --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16     --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 512 --max-tokens-valid 512 --tokens-from-prev 512 --curriculum 1000 --required-batch-size-multiple 1 --no-epoch-checkpoints

Again, you can use the update-freq/max-tokens method from above if you run out of memory.

Saved Checkpoint

If you'd like to download the Shortformer instead of training it, it is available here. Rename that file to checkpoint_best.pt if you'd like to follow the directions below.

Inference

For nonoverlapping evaluation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103     --path shortformer/checkpoint_best.pt  --sample-break-mode none --gen-subset valid   --max-sentences 1

For token-by-token generation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103     --path shortformer/checkpoint_best.pt  --sample-break-mode none --gen-subset valid   --max-sentences 1 --sliding-inf 1 --context-window 511 --max-tokens 512

(Note that --context-window is a fairseq command and doesn't have the exact meaning that the term "context window" has in our paper.)

Shortformer (without Staged Training)

Staged training improves the perplexity of the model and makes training faster, so there's no reason not to use it, but if you would like to train the Shortformer without it, the command is

python train.py --task language_modeling     data-bin/wikitext-103     --save-dir shortformer-no-st/      --arch transformer_lm_wiki103     --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75     --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1     --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16     --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 512 --max-tokens-valid 512 --tokens-from-prev 512 --curriculum 1000 --required-batch-size-multiple 1 --no-epoch-checkpoints

For inference, use the same commands as the ones for the Shortformer (above).

Baseline with Staged Training

Our Shortformer model is fast to train and for token-by-token generation, but if speed is not an issue, we can achieve slightly better performance by just applying Staged Training to the Baevski & Auli baseline LM. This model is very slow but achieves the best perplexity.

To train the first stage, download the unmodified fairseq reporsitory and then run:

python train.py --task language_modeling     data-bin/wikitext-103     --save-dir checkpoints-st-128e50/     --arch transformer_lm_wiki103     --max-update 70050 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75     --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1     --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16     --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 128  --required-batch-size-multiple 1 --save-interval 50

After that, to train the model with the second (and final) stage:

python train.py --task language_modeling     data-bin/wikitext-103     --save-dir st/ --restore-file checkpoints-st-128e50/checkpoint50.pt     --arch transformer_lm_wiki103     --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75     --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1     --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --seed 1 --fp16     --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 3072  --no-epoch-checkpoints

Inference

For nonoverlapping evaluation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103     --path st/checkpoint_best.pt  --sample-break-mode none --gen-subset valid   --max-sentences 1

For sliding window evaluation of the validation set, with a stride of 2,560, run:

fairseq-eval-lm data-bin/wikitext-103     --path st/checkpoint_best.pt  --sample-break-mode none --gen-subset valid   --max-sentences 1 --context-window 2560

Baseline - Baevski & Auli

To train the baseline, download the unmodified fairseq repository and then run:

python train.py --task language_modeling     data-bin/wikitext-103     --save-dir baseline/  --arch transformer_lm_wiki103     --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75     --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1     --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --seed 1 --fp16     --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --tokens-per-sample 3072  --no-epoch-checkpoints

Inference

Use the same commands as in the 'Baseline with Staged Training' inference subsection.

Owner
Ofir Press
PhD student @uwnlp
Ofir Press
一个运行在 𝐞𝐥𝐞𝐜𝐕𝟐𝐏 或 𝐪𝐢𝐧𝐠𝐥𝐨𝐧𝐠 等定时面板的签到项目

定时面板上的签到盒 一个运行在 𝐞𝐥𝐞𝐜𝐕𝟐𝐏 或 𝐪𝐢𝐧𝐠𝐥𝐨𝐧𝐠 等定时面板的签到项目 𝐞𝐥𝐞𝐜𝐕𝟐𝐏 𝐪𝐢𝐧𝐠𝐥𝐨𝐧𝐠 特别声明 本仓库发布的脚本及其中涉及的任何解锁和解密分析脚本,仅用于测试和学习研究,禁止用于商业用途,不能保证其合

Leon 1.1k Dec 30, 2022
[CVPR 2021] NormalFusion: Real-Time Acquisition of Surface Normals for High-Resolution RGB-D Scanning

NormalFusion: Real-Time Acquisition of Surface Normals for High-Resolution RGB-D Scanning Project Page | Paper | Supplemental material #1 | Supplement

KAIST VCLAB 49 Nov 24, 2022
Reinforcement learning algorithms in RLlib

raylab Reinforcement learning algorithms in RLlib and PyTorch. Installation pip install raylab Quickstart Raylab provides agents and environments to b

Ângelo 50 Sep 08, 2022
Codes for our IJCAI21 paper: Dialogue Discourse-Aware Graph Model and Data Augmentation for Meeting Summarization

DDAMS This is the pytorch code for our IJCAI 2021 paper Dialogue Discourse-Aware Graph Model and Data Augmentation for Meeting Summarization [Arxiv Pr

xcfeng 55 Dec 27, 2022
Software & Hardware to do multi color printing with Sharpies

3D Print Colorizer is a combination of 3D printed parts and a Cura plugin which allows anyone with an Ender 3 like 3D printer to produce multi colored

343 Jan 06, 2023
A simple PyTorch Implementation of Generative Adversarial Networks, focusing on anime face drawing.

AnimeGAN A simple PyTorch Implementation of Generative Adversarial Networks, focusing on anime face drawing. Randomly Generated Images The images are

Jie Lei 雷杰 1.2k Jan 03, 2023
《Rethinking Sptil Dimensions of Vision Trnsformers》(2021)

Rethinking Spatial Dimensions of Vision Transformers Byeongho Heo, Sangdoo Yun, Dongyoon Han, Sanghyuk Chun, Junsuk Choe, Seong Joon Oh | Paper NAVER

NAVER AI 224 Dec 27, 2022
Gradient Inversion with Generative Image Prior

Gradient Inversion with Generative Image Prior This repository is an implementation of "Gradient Inversion with Generative Image Prior", accepted to N

MLLab @ Postech 25 Jan 09, 2023
Txt2Xml tool will help you convert from txt COCO format to VOC xml format in Object Detection Problem.

TXT 2 XML All codes assume running from root directory. Please update the sys path at the beginning of the codes before running. Over View Txt2Xml too

Nguyễn Trường Lâu 4 Nov 24, 2022
Rl-quickstart - Reinforcement Learning Quickstart

Reinforcement Learning Quickstart To get setup with the repository, git clone ht

UCLA DataRes 3 Jun 16, 2022
StarGAN - Official PyTorch Implementation (CVPR 2018)

StarGAN - Official PyTorch Implementation ***** New: StarGAN v2 is available at https://github.com/clovaai/stargan-v2 ***** This repository provides t

Yunjey Choi 5.1k Jan 04, 2023
The codes and related files to reproduce the results for Image Similarity Challenge Track 1.

ISC-Track1-Submission The codes and related files to reproduce the results for Image Similarity Challenge Track 1. Required dependencies To begin with

Wenhao Wang 115 Jan 02, 2023
a pytorch implementation of auto-punctuation learned character by character

Learning Auto-Punctuation by Reading Engadget Articles Link to Other of my work 🌟 Deep Learning Notes: A collection of my notes going from basic mult

Ge Yang 137 Nov 09, 2022
for taichi voxel-challange event

Taichi Voxel Challenge Figure: result of python3 example6.py. Please replace the image above (demo.jpg) with yours, so that other people can immediate

Liming Xu 20 Nov 26, 2022
PIGLeT: Language Grounding Through Neuro-Symbolic Interaction in a 3D World [ACL 2021]

piglet PIGLeT: Language Grounding Through Neuro-Symbolic Interaction in a 3D World [ACL 2021] This repo contains code and data for PIGLeT. If you like

Rowan Zellers 51 Oct 08, 2022
Keras Realtime Multi-Person Pose Estimation - Keras version of Realtime Multi-Person Pose Estimation project

This repository has become incompatible with the latest and recommended version of Tensorflow 2.0 Instead of refactoring this code painfully, I create

M Faber 769 Dec 08, 2022
StyleTransfer - Open source style transfer project, based on VGG19

StyleTransfer - Open source style transfer project, based on VGG19

Patrick martins de lima 9 Dec 13, 2021
Bonnet: An Open-Source Training and Deployment Framework for Semantic Segmentation in Robotics.

Bonnet: An Open-Source Training and Deployment Framework for Semantic Segmentation in Robotics. By Andres Milioto @ University of Bonn. (for the new P

Photogrammetry & Robotics Bonn 314 Dec 30, 2022
Gym environment for FLIPIT: The Game of "Stealthy Takeover"

gym-flipit Gym environment for FLIPIT: The Game of "Stealthy Takeover" invented by Marten van Dijk, Ari Juels, Alina Oprea, and Ronald L. Rivest. Desi

Lisa Oakley 2 Dec 15, 2021
PyTorch version of the paper 'Enhanced Deep Residual Networks for Single Image Super-Resolution' (CVPRW 2017)

About PyTorch 1.2.0 Now the master branch supports PyTorch 1.2.0 by default. Due to the serious version problem (especially torch.utils.data.dataloade

Sanghyun Son 2.1k Jan 01, 2023