Complete the code of prefix-tuning in low data setting

Overview

Prefix Tuning

Note:

作者在论文中提到使用真实的word去初始化prefix的操作(Initializing the prefix with activations of real words,significantly improves generation)。我在使用作者提供的代码时遇到了一些问题,因此按照代码的思路添加了利用真实词汇进行初始化的内容。

可以采用以下的方式运行:

Train

cd seq2seq; 

python train_bart.py --mode xsum --preseqlen 200 --do_train yes --fp16 yes --bsz 16  --epoch 30  --gradient_accumulation_step 3 --learning_rate 0.00005  --mid_dim 800 --use_lowdata_token 'yes' --lowdata_token 'summarize'

其中use_lowdata_token表示是否采用real word初始化的方式;lowdata_token表示传入的real word.

Decode

cd seq2seq; 

python train_bart.py --mode xsum --do_train no --prefix_model_path {checkpoint_path} --preseqlen {same as training} --mid_dim {same as training} --use_lowdata_token 'yes' --lowdata_token 'summarize'

Files:

.
├── gpt2                          # Code for GPT2 style autoregressive LM
│   ├── train_e2e.py              # high-level scripts to train.
│   ├── train_control.py          # code that implements prefix-tuning.
│   ├── trainer_prefix.py         # trainer code for the training loop. 
│   ├── run_language_modeling.py  # training code (contains data loading, model loading, and calls trainer)
│   ├── gen.py                    # high-level scripts to decode. 
│   └── run_generation.py         # decoding code. 
│
├── seq2seq                       # Code for encoder-decoder architecture
│   ├── train_bart.py             # high-level scripts to train.
│   ├── prefixTuning.py           # code that implements prefix-tuning.
│   ├── finetune.py               # training code (contains data loading, model loading, and calls trainer)   
│   ├── lightning_base.py         # helper code
│   ├── utils.py                  # helper code
│   └── callbacks.py              # helper code
└── ...

To run the code for GPT2 style autoregressive LM, the code is in gpt2/. This corresponds to the table-to-text experiments in the paper.

To run the code for encoder-decoder architecture like BART, the code is in seq2seq. This corresponds to the summarization experiments in the paper.

The two primary scripts I used to run my codes are gpt2/train_e2e.py (for table-to-text) and seq2seq/train_bart.py(for summarization). they are set to default of good hyperparameters, and can be used to tune hyperparameter :)


Setup:

cd transformer; pip install -e .


Train via prefix-tuning:

cd gpt2;

python train_e2e.py --optim_prefix yes --preseqlen 5 --epoch 5 --learning_rate 0.00005 --mode webnlg --bsz 5 --seed 101
cd seq2seq; 

python train_bart.py --mode xsum --preseqlen 200 --do_train yes --fp16 yes --bsz 16  --epoch 30  --gradient_accumulation_step 3 --learning_rate 0.00005  --mid_dim 800

Other baseline approaches

cd gpt2;

python train_e2e.py --tuning_mode {finetune/adaptertune} --epoch 5 --learning_rate 0.00005 --mode webnlg --bsz 5 --seed 101
cd seq2seq;

python train_e2e.py --tuning_mode finetune --epoch 5 --learning_rate 0.00005 --mode webnlg --bsz 5 --seed 101

Decode:

cd gpt2;

python gen.py {data2text/webnlg/...} yes test {checkpoint_path} no
cd seq2seq; 

python train_bart.py --mode xsum --do_train no --prefix_model_path {checkpoint_path} --preseqlen {same as training} --mid_dim {same as training}

For details of the methods and results, please refer to our paper.

@misc{li2021prefixtuning,
      title={Prefix-Tuning: Optimizing Continuous Prompts for Generation}, 
      author={Xiang Lisa Li and Percy Liang},
      year={2021},
      eprint={2101.00190},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
Owner
Andrew Zeng
Andrew Zeng
Andrew Zeng
Revisiting, benchmarking, and refining Heterogeneous Graph Neural Networks.

Heterogeneous Graph Benchmark Revisiting, benchmarking, and refining Heterogeneous Graph Neural Networks. Roadmap We organize our repo by task, and on

THUDM 176 Dec 17, 2022
A hyperparameter optimization framework

Optuna: A hyperparameter optimization framework Website | Docs | Install Guide | Tutorial Optuna is an automatic hyperparameter optimization software

7.4k Jan 04, 2023
Convert weight file.pth to weight file.blob

CONVERT YOUR MODEL TO IR FORMAT INSTALLATION OpenVino Toolkit Download openvinotoolkit 2021.3 version : Link Instruction of installation : Link Pytorc

Tran Anh Tuan 3 Nov 18, 2021
OpenVINO黑客松比赛项目

Window_Guard OpenVINO黑客松比赛项目 英文名称:Window_Guard 中文名称:窗口卫士 硬件 树莓派4B 8G版本 一个磁石开关 USB摄像头(MP4视频文件也可以) 软件(库) OpenVINO RPi 使用方法 本项目使用的OPenVINO是是2021.3版本,并使用了

Tango 6 Jul 04, 2021
Seeing if I can put together an interactive version of 3b1b's Manim in Streamlit

streamlit-manim Seeing if I can put together an interactive version of 3b1b's Manim in Streamlit Installation I had to install pango with sudo apt-get

Adrien Treuille 6 Aug 03, 2022
DeconvNet : Learning Deconvolution Network for Semantic Segmentation

DeconvNet: Learning Deconvolution Network for Semantic Segmentation Created by Hyeonwoo Noh, Seunghoon Hong and Bohyung Han at POSTECH Acknowledgement

Hyeonwoo Noh 325 Oct 20, 2022
SpanNER: Named EntityRe-/Recognition as Span Prediction

SpanNER: Named EntityRe-/Recognition as Span Prediction Overview | Demo | Installation | Preprocessing | Prepare Models | Running | System Combination

NeuLab 104 Dec 17, 2022
Deepparse is a state-of-the-art library for parsing multinational street addresses using deep learning

Here is deepparse. Deepparse is a state-of-the-art library for parsing multinational street addresses using deep learning. Use deepparse to Use the pr

GRAAL/GRAIL 192 Dec 20, 2022
Code for ICML 2021 paper: How could Neural Networks understand Programs?

OSCAR This repository contains the source code of our ICML 2021 paper How could Neural Networks understand Programs?. Environment Run following comman

Dinglan Peng 115 Dec 17, 2022
Pytorch implementation of paper Semi-supervised Knowledge Transfer for Deep Learning from Private Training Data

Pytorch implementation of paper Semi-supervised Knowledge Transfer for Deep Learning from Private Training Data

Hrishikesh Kamath 31 Nov 20, 2022
Aerial Imagery dataset for fire detection: classification and segmentation (Unmanned Aerial Vehicle (UAV))

Aerial Imagery dataset for fire detection: classification and segmentation using Unmanned Aerial Vehicle (UAV) Title FLAME (Fire Luminosity Airborne-b

79 Jan 06, 2023
Source code for paper "Deep Superpixel-based Network for Blind Image Quality Assessment"

DSN-IQA Source code for paper "Deep Superpixel-based Network for Blind Image Quality Assessment" Requirements Python =3.8.0 Pytorch =1.7.1 Usage wit

7 Oct 13, 2022
Online Multi-Granularity Distillation for GAN Compression (ICCV2021)

Online Multi-Granularity Distillation for GAN Compression (ICCV2021) This repository contains the pytorch codes and trained models described in the IC

Bytedance Inc. 299 Dec 16, 2022
A Home Assistant custom component for Lobe. Lobe is an AI tool that can classify images.

Lobe This is a Home Assistant custom component for Lobe. Lobe is an AI tool that can classify images. This component lets you easily use an exported m

Kendell R 4 Feb 28, 2022
Model serving at scale

Run inference at scale Cortex is an open source platform for large-scale machine learning inference workloads. Workloads Realtime APIs - respond to pr

Cortex Labs 7.9k Jan 06, 2023
Incorporating Transformer and LSTM to Kalman Filter with EM algorithm

Deep learning based state estimation: incorporating Transformer and LSTM to Kalman Filter with EM algorithm Overview Kalman Filter requires the true p

zshicode 57 Dec 27, 2022
A curated list of awesome resources combining Transformers with Neural Architecture Search

A curated list of awesome resources combining Transformers with Neural Architecture Search

Yash Mehta 173 Jan 03, 2023
My 1st place solution at Kaggle Hotel-ID 2021

1st place solution at Kaggle Hotel-ID My 1st place solution at Kaggle Hotel-ID to Combat Human Trafficking 2021. https://www.kaggle.com/c/hotel-id-202

Kohei Ozaki 18 Aug 19, 2022
Official Implementation of "Transformers Can Do Bayesian Inference"

Official Code for the Paper "Transformers Can Do Bayesian Inference" We train Transformers to do Bayesian Prediction on novel datasets for a large var

AutoML-Freiburg-Hannover 103 Dec 25, 2022
Reinforcement Learning for the Blackjack

Reinforcement Learning for Blackjack Author: ZHA Mengyue Math Department of HKUST Problem Statement We study playing Blackjack by reinforcement learni

Dolores 3 Jan 24, 2022