Standalone pre-training recipe with JAX+Flax

Overview

Sabertooth

Sabertooth is a standalone pre-training recipe based on JAX+Flax, with data pipelines implemented in Rust. It runs on CPU, GPU, and/or TPU, but this README targets TPU usage.

Contents

  1. Installation
  2. Data preparation
  3. Tokenizer preparation
  4. Pre-training
  5. Fine-tuning on GLUE

Installation

Automatic installation in TPU VMs

The TPU code is not intended to work with the "JAX Cloud TPU Preview" that uses tpu_driver, but rather only with the "new JAX on Cloud TPU in private alpha" that involves direct SSH access. In order to sign up for private alpha access, please follow this link.

Rather than creating VMs manually and installing dependencies by hand, you can use the scripts in tpu_management to help with this.

Before using these scripts:

  • Follow the one-time setup instructions in the "Cloud TPU VM Alpha User Guide" document up to (but not including) the point where you actually create a TPU VM.
  • Then follow the instructions in tpu_management/config.env.template to set configuration variables used by the TPU management scripts.

After that, the following TPU management commands are available:

  • tpu_management/tpu.sh create NAME creates a single-host TPU VM with the specified NAME and prints an ssh config for accessing the VM. Adding this config entry to ~/.ssh/config will then allow you to use ssh NAME to access the VM.
  • tpu_management/tpu.sh provision NAME will push this git repository to the created VM, and build/install all required software within the VM. It also sets up NAME as a git remote, so you can do git push NAME ... to push code to the VM, and git fetch NAME to fetch commits developed within the VM.
  • tpu_management/tpu.sh delete NAME deletes the TPU instance. All files on the TPU VM filesystem will be lost.
  • tpu_management/tpu.sh config-ssh NAME prints the SSH config you need to connect to the VM. This command is useful when a VM automatically restarts and gets assigned a new IP address.

The tpu_management/pod.sh behaves similarly, but for multi-host TPU training:

  • tpu_management/pod.sh create NAME creates a multi-host TPU setup with the specified NAME and sets up an ssh config entry for each worker (${NAME}0, ${NAME}1, ...).
  • tpu_management/pod.sh provision NAME configures all worker VMs and sets up a git remote such that git push NAME ... pushes to all four workers.
  • tpu_management/pod.sh delete NAME deletes the TPU instance, including all workers. All files on the worker VM filesystems will be lost.
  • tpu_management/pod.sh config-ssh NAME prints the SSH config you need to connect to the workers. This command is useful when a VM automatically restarts and gets assigned a new IP address.

Note: on a brand new TPU VM, doing import tensorflow for the first time can take minutes. This issue goes away on all subsequent calls. If a pre-training/fine-tuning script appears to hang on a new VM, this is probably the cause. You can test this by waiting out python3 -c "import tensorflow" and then running the actual script you want.

Manual installation

For Python dependencies, see requirements_tpu.txt. The input pipeline and data processing scripts are implemented in Rust. See https://rustup.rs/ for a one-line shell command that installs Rust.

Run ./install_sabertooth_pipeline.sh to build and install the sabertooth_pipeline helper package into the currently active Python environment. This requires CMake to be installed.

Data preparation

Accepted formats

Sabertooth accepts pre-training data in either of the following formats:

  • Text format with one sentence per line, with blank lines in between documents (the BERT format). With this format, sentence segmentation is assumed to have been fully carried out during pre-processing.
  • JSONlines format, which is automatically used for all files ending in .jsonl. A zstandard-compressed version is also accepted and is used for all files ending in .jsonl.zst or plain .zst. Each line must match the schema {"text": "[JSON-encoded text of an entire document...]"} (the Pile format). With this format, sentence segmentation is performed at training time by background CPU threads that are also responsible for tokenization.
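To make the two formats concrete, here is a minimal Python sketch of how a reader might turn each one into documents. It is an illustration only, not Sabertooth's Rust pipeline: the function names (read_bert_format, read_jsonl_format, detect_format) are hypothetical, and decompression of .zst files is left out (the third-party zstandard package could supply it).

```python
import json
from typing import Iterable, Iterator, List


def read_bert_format(lines: Iterable[str]) -> Iterator[List[str]]:
    """Group one-sentence-per-line text into documents separated by blank lines."""
    doc: List[str] = []
    for line in lines:
        line = line.strip()
        if line:
            doc.append(line)  # one pre-segmented sentence
        elif doc:  # a blank line closes the current document
            yield doc
            doc = []
    if doc:  # the last document may not be followed by a blank line
        yield doc


def read_jsonl_format(lines: Iterable[str]) -> Iterator[str]:
    """Yield the "text" field of each JSON object; each line is one whole document."""
    for line in lines:
        if line.strip():
            yield json.loads(line)["text"]


def detect_format(filename: str) -> str:
    """Hypothetical dispatch mirroring the extension rules listed above."""
    if filename.endswith(".jsonl"):
        return "jsonl"
    if filename.endswith(".zst"):  # covers both .jsonl.zst and plain .zst
        return "jsonl.zst"
    return "bert"
```

Note that the JSONlines reader yields raw document text, which still needs sentence segmentation, whereas the BERT-format reader already yields lists of sentences.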

Pre-processed downloads

If you don't want to run the processing commands described in the next subsection, you can skip ahead by downloading already processed data from one of the following sources:

  • English Wikipedia in BERT format (4.5GB compressed download), courtesy of GluonNLP. The processing scripts in this repository are largely identical to the GluonNLP processing scripts, except that ours are implemented in Rust. The only downside of this download is that it does not include any Books data (potentially resulting in a lower GLUE score after pre-training), and does not shuffle the full corpus.
  • The Pile: 825GB of text from diverse sources, although some of this data is not as clean as our processed Wikipedia data. Training BERT-base on the Pile consistently achieves a GLUE test score above 77, but we have not yet found the right set of hyperparameters for matching the effectiveness of Wiki+Books data.

Preparing wikibooks data

Downloading and extracting: Wikipedia

First, download the Wikipedia dump and extract the pages. The Wikipedia dump can be downloaded from this link, and should contain the following file: enwiki-latest-pages-articles-multistream.xml.bz2. Note that older database dumps are periodically deleted from the official Wikimedia downloads website, so we can't pin a specific version of Wikipedia without hosting a mirror of the full dataset.

Next, run WikiExtractor (rust/create_pretraining_data/WikiExtractor.py) to extract the wiki pages from the XML. The extracted pages are stored as <data dir>/LL/wiki_nn, for example <data dir>/AA/wiki_00. Each file is ~1MB, and each subdirectory holds 100 files, wiki_00 through wiki_99, except the last subdirectory. For the dump from around December 2020, the last file is FL/wiki_09.

SABERTOOTH_DIR="$(pwd)"    # Run these commands from the root of the sabertooth repository
DATA_ROOT="$HOME/prep_sabertooth"
mkdir -p $DATA_ROOT
cd $DATA_ROOT
wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles-multistream.xml.bz2    # Optionally use curl instead
bzip2 -d enwiki-latest-pages-articles-multistream.xml.bz2
python3 $SABERTOOTH_DIR/rust/create_pretraining_data/WikiExtractor.py enwiki-latest-pages-articles-multistream.xml    # Results are placed in text/
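As a sanity check after extraction, the expected file list can be enumerated from the LL/wiki_nn layout described above. This is a hypothetical helper, assuming subdirectories are named AA, AB, ..., AZ, BA, ... in order. With FL/wiki_09 as the last file it predicts 141 full subdirectories plus 10 files in FL, i.e. 14,110 files (roughly 14GB at ~1MB per file).

```python
import string
from typing import List


def expected_wiki_files(last: str) -> List[str]:
    """List every LL/wiki_nn path from AA/wiki_00 up to and including `last`."""
    paths: List[str] = []
    for first in string.ascii_uppercase:
        for second in string.ascii_uppercase:
            for n in range(100):  # each full subdirectory holds wiki_00 .. wiki_99
                path = f"{first}{second}/wiki_{n:02d}"
                paths.append(path)
                if path == last:
                    return paths
    raise ValueError(f"{last!r} is not of the form LL/wiki_nn")
```

Comparing len(expected_wiki_files("FL/wiki_09")) against the number of files matched by text/??/wiki_?? is a quick way to spot an interrupted extraction.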

Downloading and extracting: Books

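Download and extract books1 so that its text files end up in $DATA_ROOT/books1/epubtxt/, the path expected by the preprocessing command below.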

Preprocessing and sharding

DATA_ROOT="$HOME/prep_sabertooth"
mkdir -p $DATA_ROOT/wikibooks
cd rust/create_pretraining_data    # Run from the root of the sabertooth repository
cargo build --release  # Removing `--release` will make the code *much* slower
target/release/create_pretraining_data --output $DATA_ROOT/wikibooks --num-shards 500 --wiki $DATA_ROOT/'text/??/wiki_??' --books $DATA_ROOT/'books1/epubtxt/*.txt'

This will write 500 shards in JSONlines format to the folder wikibooks/; each should be around 40MB in size. Set the RAYON_NUM_THREADS environment variable to limit the number of parallel threads used for processing; by default, a thread will be spawned per CPU core.

create_pretraining_data loads the full dataset into RAM to perform a global shuffle. If you run out of memory, you can run it multiple times, each time with a subset of the files. --wiki and --books both accept glob patterns, so you can do e.g. --wiki 'text/A?/wiki_??'. The --wiki and --books flags can also be repeated (with different files each time) or omitted. The JSONlines format should also make it easy to manually concatenate/split/shuffle shards of data.
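Memory scales with the corpus because a global shuffle needs every document available before any shard is written. The following Python sketch shows the idea with hypothetical helper names; it is not the Rust implementation, and dealing the shuffled documents round-robin into shards is an assumption made for illustration.

```python
import json
import random
from typing import List, Sequence


def shuffle_and_shard(documents: Sequence[str], num_shards: int, seed: int = 0) -> List[List[str]]:
    """Shuffle all documents globally, then deal them round-robin into num_shards shards."""
    order = list(range(len(documents)))
    random.Random(seed).shuffle(order)  # the whole corpus must be held in memory here
    shards: List[List[str]] = [[] for _ in range(num_shards)]
    for i, doc_index in enumerate(order):
        shards[i % num_shards].append(documents[doc_index])
    return shards


def shard_to_jsonl(shard: List[str]) -> str:
    """Serialize one shard with the {"text": ...} schema, one document per line."""
    return "".join(json.dumps({"text": doc}) + "\n" for doc in shard)
```

Running the tool separately on subsets of the input, as suggested above, amounts to shuffling within each subset only, which is usually an acceptable trade-off for the memory saved.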

Compressing the data

Compressing the data is optional, but it can save space when storing and transferring large datasets. All of our code supports uncompressed .jsonl and zstandard-compressed .jsonl.zst files interchangeably, so you can also skip this step (and change .jsonl.zst to .jsonl in all commands below).

To compress: download zstandard, compile it with make, and run a compression command such as zstd $DATA_ROOT/wikibooks/*.jsonl.

The overhead of decompressing data is negligible when compared to more costly operations like tokenization or sentence segmentation, which we will be doing at training time anyway. The only downside of compressing the data is that you can't inspect compressed files with a simple text viewer/editor.

Tokenizer preparation

We use SentencePiece for tokenization.

For large datasets, feeding the raw data directly to the sentencepiece trainer is very slow and memory-hungry (RAM usage is at least 10x the size of the uncompressed raw text, and there are single-threaded O(n) sections in the sentencepiece trainer code).

Instead, we will first count unique words in our data, separated only by whitespace. We will then pass a TSV file containing the counts to the sentencepiece trainer, which will further decompose them and build a subword vocabulary.

To count the tokens in our data, run:

DATA_ROOT="$HOME/prep_sabertooth"
pushd rust/count_tokens
cargo build --release
popd
rust/count_tokens/target/release/count_tokens $DATA_ROOT/wikibooks/*.jsonl.zst > $DATA_ROOT/counts_wikibooks.tsv
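The counting step itself is simple. Here is a minimal Python sketch of the idea, assuming a word<TAB>count line layout for the TSV; count_words and counts_to_tsv are hypothetical names, and the column order and sort order are assumptions rather than documented count_tokens behavior.

```python
from collections import Counter
from typing import Iterable


def count_words(documents: Iterable[str]) -> Counter:
    """Count words delimited only by whitespace; no lowercasing or punctuation splitting."""
    counts: Counter = Counter()
    for doc in documents:
        counts.update(doc.split())  # split() with no argument splits on any run of whitespace
    return counts


def counts_to_tsv(counts: Counter) -> str:
    """Write one word<TAB>count line per distinct word, most frequent first."""
    return "".join(f"{word}\t{n}\n" for word, n in counts.most_common())
```

Because words are split on whitespace, they can never contain a tab, so the TSV needs no escaping. Punctuation stays attached ("cat." and "cat" are counted separately); splitting such words into subwords is left to the sentencepiece trainer.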

count_tokens accepts filenames as arguments and supports the following formats:

  • .jsonl: JSONlines data, where each line matches the schema {"text": "[JSON-encoded text of an entire document...]"}
  • .jsonl.zst and .zst: same as above, but zstandard-compressed
  • .tsv: merges counts from another tsv file into the output. If your available CPU RAM is not sufficient to count your corpus all at once, you can count shards of the data separately and then merge the tsv files
  • All other extensions are treated as plain-text data
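Merging per-shard counts is just a sum over words. The sketch below assumes a word<TAB>count line layout (an assumption, not documented count_tokens behavior) and takes the count from the last column of each line; merge_tsv is an illustrative name, not part of count_tokens.

```python
from collections import Counter
from typing import Iterable


def merge_tsv(tsv_texts: Iterable[str]) -> Counter:
    """Sum word<TAB>count lines from several TSV files (e.g. one per data shard)."""
    total: Counter = Counter()
    for text in tsv_texts:
        for line in text.splitlines():
            if line:
                word, count = line.rsplit("\t", 1)
                total[word] += int(count)
    return total
```

Only the merged table has to fit in memory at once, which is what makes counting shards separately worthwhile.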

Once the TSV file is created, use our provided script to train the tokenizer:

python3 rust/count_tokens/train_tokenizer.py --input $DATA_ROOT/counts_wikibooks.tsv --model_prefix $DATA_ROOT/wikibooks_32k --vocab_size 32128

After this you're ready to run pre-training!

Tip for data transfer: The data folder (with zstandard-compressed shards) can be packed into a tar archive using the command cd $DATA_ROOT/.. && tar cvf prep_sabertooth.tar prep_sabertooth/wikibooks/*.jsonl.zst prep_sabertooth/*.model prep_sabertooth/*.vocab. If you have multiple TPU VMs (either multiple workers for multi-host training, or just multiple VMs for different jobs), scp-ing data from one VM to another is 10x-100x faster than copying via gsutil cp or from any non-gcloud machine. SSH into a TPU VM and run ifconfig to determine its internal IP address, which other VMs can use to reach it (typically of the form 10.x.y.z).

Pre-training

For single-host pre-training, run a command such as:

python3 run_pretraining.py --config=configs/pretraining.py --config.train_batch_size=1024 --config.optimizer="adam" --config.learning_rate=1e-4 --config.num_train_steps=1000000 --config.num_warmup_steps=10000 --config.max_seq_length=128 --config.max_predictions_per_seq=20

For BERT base, our best hyperparameter setting thus far involves pre-training with global batch size 4096. To launch this training on TPUv3-32 with 4 hosts, set up each host with the required environment variables and then run the following command:

python3 run_pretraining.py --config=configs/pretraining.py --config.optimizer="adam" --config.train_batch_size=4096 --config.learning_rate=1e-3 --config.num_train_steps=125000 --config.num_warmup_steps=3125 --config.adam_epsilon=1e-11 --config.adam_beta1=0.9 --config.adam_beta2=0.98 --config.weight_decay=0.1 --config.max_grad_norm=0.4

Use --config.input_files and --config.tokenizer to configure dataset and tokenizer paths for pre-training (see configs/pretraining.py for the full set of configuration options and hyperparameters).
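The --config.num_warmup_steps and --config.num_train_steps flags imply a warmup-then-decay learning rate schedule. As an assumption for illustration, here is the schedule from the original BERT code (linear warmup to the peak rate, then linear decay to zero computed over the whole run) in plain Python; check configs/pretraining.py for the schedule this codebase actually uses.

```python
def bert_learning_rate(step: int, peak_lr: float, num_warmup_steps: int, num_train_steps: int) -> float:
    """Linear warmup to peak_lr, then linear decay reaching zero at num_train_steps."""
    if step < num_warmup_steps:
        return peak_lr * step / num_warmup_steps
    # The decay is measured from step 0, so right after warmup the rate is
    # slightly below peak_lr (by a factor of 1 - num_warmup_steps / num_train_steps).
    return peak_lr * max(0.0, 1.0 - step / num_train_steps)
```

With the multi-host settings above (peak 1e-3, 3125 warmup steps out of 125000), warmup covers 2.5% of training, and the rate at the end of warmup would be 0.975 × 1e-3 under this assumed schedule.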

Pre-training notes

Our pre-training recipe is close to BERT, but there are a few differences:

  • We use a SentencePiece unigram tokenizer, instead of WordPiece
  • The next-sentence prediction (NSP) task from BERT is replaced with the sentence order prediction (SOP) task from ALBERT. We do this primarily to simplify the data pipeline implementation, but past work has observed SOP to give better results than NSP.
  • BERT's Adam optimizer departs from the Adam paper in that it omits bias correction terms. This codebase uses Flax's implementation of Adam, which includes bias correction.
  • Pre-training uses a fixed maximum sequence length of 128, and does not increase the sequence length to 512 for the last 10% of training.
  • The wiki+books data used in this repository is designed to match the BERT paper as closely as possible, but it's not identical. The data used by BERT was never publicly available, so most BERT replications have this property.
  • Random masking and sentence shuffling occur each time a batch of examples is sampled during training, rather than once during the data generation step.
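To illustrate what can happen each time a batch is sampled, here is a Python sketch of ALBERT-style sentence order prediction and BERT-style masking. It assumes the masking recipe from the BERT paper (15% of tokens selected, of which 80% become [MASK], 10% a random token, and 10% stay unchanged), a hypothetical MASK_ID, and a label convention of 0 = original order; Sabertooth's Rust pipeline may differ in these details. max_predictions plays the role of --config.max_predictions_per_seq.

```python
import random
from typing import List, Sequence, Set, Tuple

MASK_ID = 4  # hypothetical id of the [MASK] token


def make_sop_pair(sentences: Sequence[str], rng: random.Random) -> Tuple[List[str], List[str], int]:
    """Split consecutive sentences into two non-empty segments; swap them half the time."""
    if len(sentences) < 2:
        raise ValueError("SOP needs at least two sentences")
    split = rng.randint(1, len(sentences) - 1)  # inclusive bounds keep both segments non-empty
    first, second = list(sentences[:split]), list(sentences[split:])
    if rng.random() < 0.5:
        return second, first, 1  # label 1: segments swapped
    return first, second, 0  # label 0: original order


def mask_tokens(
    token_ids: Sequence[int], special_ids: Set[int], vocab_size: int,
    max_predictions: int, rng: random.Random, mask_prob: float = 0.15,
) -> Tuple[List[int], List[int], List[int]]:
    """BERT-style 80/10/10 masking; returns (masked_ids, positions, original_ids)."""
    candidates = [i for i, t in enumerate(token_ids) if t not in special_ids]
    rng.shuffle(candidates)
    num_to_mask = min(max_predictions, max(1, round(len(token_ids) * mask_prob)))
    positions = sorted(candidates[:num_to_mask])
    masked = list(token_ids)
    for pos in positions:
        r = rng.random()
        if r < 0.8:
            masked[pos] = MASK_ID  # 80%: replace with [MASK]
        elif r < 0.9:
            masked[pos] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token, but still predict it
    return masked, positions, [token_ids[p] for p in positions]
```

Because both functions draw from the rng at sampling time, a document seen twice during training generally gets a different segment split and a different set of masked positions.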
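The bias-correction difference between BERT's optimizer and Flax's Adam can be seen in a scalar Python sketch of one Adam step, computed with and without the correction terms. The function and its default hyperparameters are hypothetical, chosen for illustration; it leaves out weight decay and gradient clipping.

```python
import math
from typing import Tuple


def adam_update(
    grad: float, m: float, v: float, step: int,
    lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-6,
    bias_correction: bool = True,
) -> Tuple[float, float, float]:
    """One scalar Adam step; returns (update, new_m, new_v). step counts from 1."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    if bias_correction:  # as in the Adam paper
        m_hat = m / (1 - beta1 ** step)
        v_hat = v / (1 - beta2 ** step)
    else:  # BERT's optimizer skips these two divisions
        m_hat, v_hat = m, v
    return lr * m_hat / (math.sqrt(v_hat) + eps), m, v
```

At step 1 with beta1=0.9 and beta2=0.999, the uncorrected update is about (1 - 0.9) / sqrt(1 - 0.999) ≈ 3.2 times larger than the corrected one. The gap shrinks as beta1**step and beta2**step approach zero, so the two variants differ mainly early in training.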

Fine-tuning on GLUE

Sample command for fine-tuning on GLUE:

python3 run_classifier.py --config=configs/classifier.py --config.init_checkpoint="/path/to/checkpoint/folder/" --config.dataset_name="cola" --config.learning_rate="5e-5"

The dataset_name should be one of: cola, mrpc, qqp, sst2, stsb, mnli, qnli, rte. WNLI is not supported because BERT accuracy on WNLI is below the baseline, unless a special training recipe is used.

Leaderboard evaluation

To evaluate a model on GLUE, we typically run a sweep across different learning rates and use the development set to select the best one:

OUTPUT_DIR="$HOME/glue"
./sweep_glue.sh /path/to/checkpoint/folder $OUTPUT_DIR 5e-5 4e-5 3e-5 2e-5
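The selection step of such a sweep amounts to an argmax over dev-set scores. Here is a hypothetical Python sketch; the nested-dict layout is an assumption and need not match what sweep_glue.sh writes, and for tasks that report two numbers it ranks learning rates by their mean, which is just one reasonable choice.

```python
from statistics import mean
from typing import Dict, Sequence, Tuple

# Hypothetical layout: task -> learning rate -> dev metrics (one value, or two
# for tasks such as MRPC, QQP, STS-B, and MNLI that report a pair of numbers).
DevScores = Dict[str, Dict[str, Sequence[float]]]


def pick_best_lr(dev_scores: DevScores) -> Dict[str, Tuple[str, float]]:
    """For each task, return (learning rate, mean dev metric) of the best run."""
    best: Dict[str, Tuple[str, float]] = {}
    for task, by_lr in dev_scores.items():
        lr, metrics = max(by_lr.items(), key=lambda kv: mean(kv[1]))
        best[task] = (lr, mean(metrics))
    return best
```

For example, with toy (not real) scores {"mrpc": {"5e-5": [90.0, 86.0], "3e-5": [89.0, 88.0]}}, the 3e-5 run wins on the mean even though 5e-5 has the higher F1.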

Here are the results from one such sweep, using a "base" size model trained with batch size 1024 for 1M steps (see the single-host training command above). Our understanding is that this option most closely approximates the total compute resources used to train the original BERT-base, except that we do not increase the sequence length to 512 at any point during training. The learning rates the sweep found are random to some extent, since the sweep doubles as both a learning rate search and a chance to try different random seeds.

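MRPC and QQP report F1 / accuracy, STS-B reports Pearson / Spearman correlation, and MNLI reports matched / mismatched accuracy.

      CoLA  SST-2  MRPC (F1/acc)  STS-B (P/S)  QQP (F1/acc)  MNLI (m/mm)  QNLI  RTE
dev   56.5  91.9   90.8 / 87.3    88.3 / 88.4  87.0 / 90.4   84.9 / 85.5  92.1  70.4
test  57.0  92.5   87.3 / 82.8    87.8 / 87.0  71.3 / 89.1   85.1 / 84.3  92.1  66.7
lr    3e-5  3e-5   3e-5           5e-5         3e-5          2e-5         3e-5  5e-5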

Here are the results from another "base" size model, this time trained with batch size 4096 for 125K steps (see the multi-host training command above). Note that this model sees only half as many training examples as the single-host model above: 4096 × 125K = 512M, versus 1024 × 1M ≈ 1.02B.

      CoLA  SST-2  MRPC (F1/acc)  STS-B (P/S)  QQP (F1/acc)  MNLI (m/mm)  QNLI  RTE
dev   60.3  91.9   89.6 / 85.5    87.9 / 88.1  87.2 / 90.6   84.7 / 85.0  91.7  68.6
test  51.9  92.1   88.8 / 84.7    86.2 / 85.0  70.9 / 89.0   84.5 / 84.0  91.5  65.9
lr    4e-5  4e-5   5e-5           2e-5         4e-5          3e-5         3e-5  5e-5
Owner: Nikita Kitaev