Tutorial to pretrain & fine-tune a 🤗 Flax T5 model on a TPUv3-8 with GCP

Overview

Pretrain and Fine-tune a T5 model with Flax on GCP

This tutorial details how pretrain and fine-tune a FlaxT5 model from HuggingFace using a TPU VM available on Google Cloud.

While the code is only slightly adapted from the original HuggingFace examples for pretraining and seq2seq fine-tuning, this repository is aimed to provide a comprehensive overview for the whole process, with a special focus on pitfalls due to an incorrect environment setup.

Why JAX/Flax? Thanks to the amazing work of the HuggingFace team, a good portion of the models available in the transformers library are available in Flax, a neural network library built on top of JAX. Using Flax for pretraining on TPUs is especially convenient for two reasons:

  • TPUs are optimized for training deep models, and Flax is optimized for TPUs. A benchmark conducted by Huggingface showed that a BERT pretraining with Flax takes only 65% of the time needed with Pytorch/XLA to achieve comparable performances on a TPUv3-8 VM, and only 35% of the time if compared to Pytorch training on 8 V100 GPUs. Since language models are usually pretrained for days on massive amounts of text, the speedup is especially desirable here.

  • Out-of-the-box compatibility with Pytorch and Tensorflow in the transformers framework. A Flax model can be easily converted in Pytorch, for example, by using T5ForConditionalGeneration.from_pretrained("path/to/flax/ckpt", from_flax=True).

The code and instructions contained in this repository were used to pretrain the models gsarti/t5-base-it and gsarti/t5-large-it available on the Huggingface Hub, using ~270Gb of cleaned web-scraped Italian texts. The dataset is also made available on the Huggingface Hub under the name gsarti/clean_mc4_it.

Setup on your machine

Follow the instructions detailed in the Cloud SDK Install Guide to install the gcloud client on your system.

TPU VMs only have ~100Gb of disk space, which is highly unlikely to be enough to store your raw + preprocessed dataset and all model checkpoints. GCP gives a choice among different kind of disk types and limits the total disk space a user can create. At the time of writing this, for a non-free trial user in Europe without special access the limit is 250Gb for SSD (pd-balanced) and 2Tb overall, including SSD and HDD (pd-standard).

### Define your variables
export GCP_PROJECT="<YOUR_PROJECT_NAME>"
export GCP_ZONE="<YOUR_REGION>"
export GCP_TPU_NAME="<YOUR_TPU_NAME>"

# >>>>> Run this part to add a disk to your TPU VM
export GCP_DISK_NAME="<YOUR_DISK_NAME>"
export GCP_DISK_SIZE_GB=1200
export GCP_DISK_TYPE=pd-standard


gcloud beta compute disks create $GCP_DISK_NAME \
    --project=$GCP_PROJECT \
    --type=$GCP_DISK_TYPE \
    --size="${GCP_DISK_SIZE_GB}GB" \
    --zone=$GCP_ZONE
# <<<<<

# Create the TPU VM
gcloud alpha compute tpus tpu-vm create $GCP_TPU_NAME \
    --zone $GCP_ZONE \
    --project $GCP_PROJECT \
    --accelerator-type v3-8 \
    --version v2-alpha \
    # Uncomment this if a disk is used
    #--data-disk source="projects/${GCP_PROJECT}/zones/${GCP_ZONE}/disks/${GCP_DISK_NAME}" 

Inside the TPUv3-8 machine

Log in inside the machine with the following command:

gcloud alpha compute tpus tpu-vm ssh $GCP_TPU_NAME --zone $GCP_ZONE --project $GCP_PROJECT

The system is a classic Ubuntu 20.04 with some utility libraries already available (e.g. tmux, vim). The following commands show how to setup the disk inside the machine, if you created a disk. The disk in the example is available on /dev/sdb, and we mount it on the data folder in the home:

# Check that your disk is visible and get its name
lsblk
# Mount disk in the data folder
sudo mkfs.ext4 -m 0 -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb
sudo mkdir -p data
sudo mount -o discard,defaults /dev/sdb data
sudo chmod a+w data
# Optionally, add the automatic mounting: https://cloud.google.com/compute/docs/disks/add-persistent-disk#configuring_automatic_mounting_on_vm_restart

# Permanently export the disk as the new path for HF dataset caching
echo "export HF_DATASETS_CACHE=data" >> .bashrc
source .bashrc

We then setup the environment with required libraries and dependencies:

### Setup the TPUv3-8 environment
sudo apt update
sudo apt-get install python3.8-venv
python3 -m venv venv
source venv/bin/activate
pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install -r requirements.txt
sudo apt-get install git-lfs
git lfs install

The next step is to create a repository on the Huggingface Hub (here e.g. t5-base-it) that will contain all our files. We then train a tokenizer, select a config, and push the two JSON files in the repository we created.

Depending on the number of sentences that need to be processed, this may take quite some time (~2 hours for 10M sentences).

# Huggingface Hub variables
export HF_NAME="<YOUR_HF_HUB_USERNAME>"
export HF_PWD="<YOUR_HF_HUB_PASSWORD>"
export HF_TOKEN="<YOUR_HF_HUB_API_TOKEN>"
export HF_PROJECT="t5-base-it"

# Variables for training the tokenizer and creating the config
export VOCAB_SIZE="32000"
export N_INPUT_SENTENCES="1000000" # Num of sentences to train the tokenizer
export DATASET="gsarti/clean_mc4_it" # Name of the dataset in the Huggingface Hub
export DATASET_CONFIG="full" # Config of the dataset in the Huggingface Hub 
export DATASET_SPLIT="train" # Split to use for training tokenizer and model
export TEXT_FIELD="text" # Field containing the text to be used for training
export CONFIG_TYPE="google/t5-v1_1-base" # Config that our model will use
export MODEL_PATH="data/${HF_PROJECT}" # Path to the model, e.g. here inside the mount

### Setup the project on the HF Hub
#huggingface-cli login also works
echo $HF_TOKEN > ~/.huggingface/token
huggingface-cli repo create $HF_PROJECT
git clone "https://huggingface.co/${HF_NAME}/${HF_PROJECT}"
mv $HF_PROJECT $MODEL_PATH

# Create the tokenizer and the config
python create_tokenizer_cfg.py \
    --model_dir $MODEL_PATH \
    --dataset $DATASET \
    --dataset_config $DATASET_CONFIG \
    --dataset_split $DATASET_SPLIT \
    --text_field $TEXT_FIELD \
    --vocab_size $VOCAB_SIZE \
    --input_sentence_size $N_INPUT_SENTENCES \
    --config_type $CONFIG_TYPE

# Push the tokenizer and the config to the HF Hub
cd $MODEL_PATH
git lfs track "*tfevents*"
git add .
git commit -m "Added tokenizer and config"à
# If this doesn't work, do a normal push and insert your credentials
git push "https://huggingface.co/${HF_NAME}/${HF_NAME}:${HF_PWD}@${HF_PROJECT}"
cd ..

The last step is to run the pretraining. The following command will train the model on the same dataset used to train the tokenizer, logging train/eval metrics to Tensorboard and pushing model checkpoints and artifacts to the Hub every save_steps steps:

Important: Since checkpoints are stored with Git LFS on the Hub, depending on the number of save_steps you may encur in problems with space (e.g. training a model with 1Gb ckpts for 1M steps with save_steps=10_000 will need 100Gb of free space to avoid crashes during training). Two ways to avoid problems are:

  • Use save_steps > (tot_steps / your_available_memory) * ckpt_size_gb. E.g. if we are doing 1M steps and the ckpt_size is 1GB and we only have 76Gb free, we will need save_steps higher than 1M/76Gb ~= 13158. Failing to do so will result in an interruption of the training once the memory has been exceeded.

  • Use git lfs prune --verify-remote to periodically remove old cached checkpoints from .git/lfs/objects and free up some space. If done often enough this will prevent any problem of space, but requires manual intervention of the user.

# Run the training
# Adjust the parameters to your needs
# To avoid using an extra disk for storing the model, 
python run_t5_mlm_flax.py \
    --output_dir=$MODEL_PATH \
    --model_type="t5" \
    --config_name=$MODEL_PATH \
    --tokenizer_name=$MODEL_PATH \
    --preprocessing_num_workers="96" \
    --do_train --do_eval \
    --adafactor \
    --dataset_name="gsarti/clean_mc4_it" \
    --dataset_config_name="full" \
    --max_seq_length="512" \
    --per_device_train_batch_size="8" \
    --per_device_eval_batch_size="8" \
    --learning_rate="0.005" \
    --overwrite_output_dir \
    --num_train_epochs="1" \
    --logging_steps="500" \
    --save_steps="80000" \
    --eval_steps="2500" \
    --weight_decay="0.01" \
    --warmup_steps="10000" \
    --validation_split_count="15000" \
    --push_to_hub \
    #--gradient_accumulation_steps="2" # Uncomment to use gradient accumulation
    #--resume_from_checkpoint=$MODEL_PATH # Uncomment to resume from ckpt

Exporting Checkpoints

Your model is now trained, congratulations! 🎉 Now you may want to export your final Flax model to other frameworks. This can be done easily:

python export_checkpoint.py --model_dir $MODEL_PATH
cd $MODEL_PATH
git commit -m "Added TF and PT models"
git push

Fine-tune a pretrained T5 model in Flax

TODO

Useful Tips

About GCP Billing Except for TPU costs, which can be offset by taking part in the TRC program, the second highest billing voice is the uploading of files from the VM to another destination. This happens every save_step during training, to keep the Tensorboard and the checkpoint on the HF Hub in sync. For in-region uploads, at the time of writing this the cost is up to 0.12 USD per GB. This means that for a full pretraining of 1M steps with logging every 10k steps, with checkpoints of roughly 1GB each, the cost is roughly 12$USD. To reduce this cost simply increase the amount of save_steps in the run_t5_mlm_flax.py script.

About data preprocessing Imagine you've trained a base model and now you want to train a large variant, using the same tokenzer. Even if the tokenizer is the same, if the path of the tokenizer is changed the cache will be invalidated and everything will be recomputed. If this is the case, it's not a tragedy: some extra preprocessing time will be required. But mind that the tokenization + grouping take substantial space on disk, so if the space is not enough, you will need to delete the arrows from the previous HuggingFace dataset cache. You can easily do so with something like:

find . -type f -size +2G -size -4G -exec ls -lah {} + | grep 'Aug 25 13' | wc -l

As an example, knowing that my tokenization completed Aug 25 at 13:xx and that files in the cache are 3Gb in size each, this should return the value you set for n_proc. You can get the filenames of the old preprocessing caches by removing the wc -l pipe and delete them all to leave space for the new ones.

Owner
Gabriele Sarti
PhD Student in NLP @ University of Groningen | Prev: Aindo, ILC-CNR Pisa, Skytech Montreal
Gabriele Sarti
A paper list of pre-trained language models (PLMs).

Large-scale pre-trained language models (PLMs) such as BERT and GPT have achieved great success and become a milestone in NLP.

RUCAIBox 124 Jan 02, 2023
KLUE-baseline contains the baseline code for the Korean Language Understanding Evaluation (KLUE) benchmark.

KLUE Baseline Korean(한국어) KLUE-baseline contains the baseline code for the Korean Language Understanding Evaluation (KLUE) benchmark. See our paper fo

74 Dec 13, 2022
Super easy library for BERT based NLP models

Fast-Bert New - Learning Rate Finder for Text Classification Training (borrowed with thanks from https://github.com/davidtvs/pytorch-lr-finder) Suppor

Utterworks 1.8k Dec 27, 2022
Sapiens is a human antibody language model based on BERT.

Sapiens: Human antibody language model ____ _ / ___| __ _ _ __ (_) ___ _ __ ___ \___ \ / _` | '_ \| |/ _ \ '

Merck Sharp & Dohme Corp. a subsidiary of Merck & Co., Inc. 13 Nov 20, 2022
The entmax mapping and its loss, a family of sparse softmax alternatives.

entmax This package provides a pytorch implementation of entmax and entmax losses: a sparse family of probability mappings and corresponding loss func

DeepSPIN 330 Dec 22, 2022
TruthfulQA: Measuring How Models Imitate Human Falsehoods

TruthfulQA: Measuring How Models Imitate Human Falsehoods

69 Dec 25, 2022
A repo for materials relating to the tutorial of CS-332 NLP

CS-332-NLP A repo for materials relating to the tutorial of CS-332 NLP Contents Tutorial 1: Introduction Corpus Regular expression Tokenization Tutori

Alok singh 9 Feb 15, 2022
Line as a Visual Sentence: Context-aware Line Descriptor for Visual Localization

Line as a Visual Sentence with LineTR This repository contains the inference code, pretrained model, and demo scripts of the following paper. It suppo

SungHo Yoon 158 Dec 27, 2022
Two-stage text summarization with BERT and BART

Two-Stage Text Summarization Description We experiment with a 2-stage summarization model on CNN/DailyMail dataset that combines the ability to filter

Yukai Yang (Alexis) 6 Oct 22, 2022
Milaan Parmar / Милан пармар / _米兰 帕尔马 170 Dec 13, 2022
ReCoin - Restoring our environment and businesses in parallel

Shashank Ojha, Sabrina Button, Abdellah Ghassel, Joshua Gonzales "Reduce Reuse R

sabrina button 1 Mar 14, 2022
Wrapper to display a script output or a text file content on the desktop in sway or other wlroots-based compositors

nwg-wrapper This program is a part of the nwg-shell project. This program is a GTK3-based wrapper to display a script output, or a text file content o

Piotr Miller 94 Dec 27, 2022
Geometry-Consistent Neural Shape Representation with Implicit Displacement Fields

Geometry-Consistent Neural Shape Representation with Implicit Displacement Fields [project page][paper][cite] Geometry-Consistent Neural Shape Represe

Yifan Wang 100 Dec 19, 2022
Code for PED: DETR For (Crowd) Pedestrian Detection

Code for PED: DETR For (Crowd) Pedestrian Detection

36 Sep 13, 2022
Traditional Chinese Text Recognition Dataset: Synthetic Dataset and Labeled Data

Traditional Chinese Text Recognition Dataset: Synthetic Dataset and Labeled Data Authors: Yi-Chang Chen, Yu-Chuan Chang, Yen-Cheng Chang and Yi-Ren Ye

Yi-Chang Chen 5 Dec 15, 2022
The Easy-to-use Dialogue Response Selection Toolkit for Researchers

The Easy-to-use Dialogue Response Selection Toolkit for Researchers

GMFTBY 32 Nov 13, 2022
Some embedding layer implementation using ivy library

ivy-manual-embeddings Some embedding layer implementation using ivy library. Just for fun. It is based on NYCTaxiFare dataset from kaggle (cut down to

Ishtiaq Hussain 2 Feb 10, 2022
nlp-tutorial is a tutorial for who is studying NLP(Natural Language Processing) using Pytorch

nlp-tutorial is a tutorial for who is studying NLP(Natural Language Processing) using Pytorch. Most of the models in NLP were implemented with less than 100 lines of code.(except comments or blank li

Tae-Hwan Jung 11.9k Jan 08, 2023
Turn clang-tidy warnings and fixes to comments in your pull request

clang-tidy pull request comments A GitHub Action to post clang-tidy warnings and suggestions as review comments on your pull request. What platisd/cla

Dimitris Platis 30 Dec 13, 2022
An implementation of WaveNet with fast generation

pytorch-wavenet This is an implementation of the WaveNet architecture, as described in the original paper. Features Automatic creation of a dataset (t

Vincent Herrmann 858 Dec 27, 2022