Tutorial to pretrain & fine-tune a 🤗 Flax T5 model on a TPUv3-8 with GCP

Overview

Pretrain and Fine-tune a T5 model with Flax on GCP

This tutorial details how to pretrain and fine-tune a FlaxT5 model from HuggingFace using a TPU VM available on Google Cloud.

While the code is only slightly adapted from the original HuggingFace examples for pretraining and seq2seq fine-tuning, this repository aims to provide a comprehensive overview of the whole process, with a special focus on pitfalls caused by an incorrect environment setup.

Why JAX/Flax? Thanks to the amazing work of the HuggingFace team, a good portion of the models available in the transformers library are available in Flax, a neural network library built on top of JAX. Using Flax for pretraining on TPUs is especially convenient for two reasons:

  • TPUs are optimized for training deep models, and Flax is optimized for TPUs. A benchmark conducted by HuggingFace showed that BERT pretraining with Flax takes only 65% of the time needed with PyTorch/XLA to reach comparable performance on a TPUv3-8 VM, and only 35% of the time needed by PyTorch training on 8 V100 GPUs. Since language models are usually pretrained for days on massive amounts of text, the speedup is especially desirable here.

  • Out-of-the-box compatibility with PyTorch and TensorFlow in the transformers framework. A Flax model can be easily converted to PyTorch, for example, by using T5ForConditionalGeneration.from_pretrained("path/to/flax/ckpt", from_flax=True).

The code and instructions contained in this repository were used to pretrain the models gsarti/t5-base-it and gsarti/t5-large-it available on the Huggingface Hub, using ~270GB of cleaned web-scraped Italian texts. The dataset is also available on the Huggingface Hub under the name gsarti/clean_mc4_it.

Setup on your machine

Follow the instructions detailed in the Cloud SDK Install Guide to install the gcloud client on your system.

TPU VMs only have ~100GB of disk space, which is highly unlikely to be enough to store your raw + preprocessed dataset and all model checkpoints. GCP offers a choice among different disk types and limits the total disk space a user can create. At the time of writing, for a non-free-trial user in Europe without special access the limit is 250GB for SSD (pd-balanced) and 2TB overall, including SSD and HDD (pd-standard).

### Define your variables
export GCP_PROJECT="<YOUR_PROJECT_NAME>"
export GCP_ZONE="<YOUR_REGION>"
export GCP_TPU_NAME="<YOUR_TPU_NAME>"

# >>>>> Run this part to add a disk to your TPU VM
export GCP_DISK_NAME="<YOUR_DISK_NAME>"
export GCP_DISK_SIZE_GB=1200
export GCP_DISK_TYPE=pd-standard


gcloud beta compute disks create $GCP_DISK_NAME \
    --project=$GCP_PROJECT \
    --type=$GCP_DISK_TYPE \
    --size="${GCP_DISK_SIZE_GB}GB" \
    --zone=$GCP_ZONE
# <<<<<

# Create the TPU VM
# If a disk is used, add a trailing backslash to the --version line
# and uncomment the --data-disk line
gcloud alpha compute tpus tpu-vm create $GCP_TPU_NAME \
    --zone $GCP_ZONE \
    --project $GCP_PROJECT \
    --accelerator-type v3-8 \
    --version v2-alpha
    #--data-disk source="projects/${GCP_PROJECT}/zones/${GCP_ZONE}/disks/${GCP_DISK_NAME}"

Inside the TPUv3-8 machine

Log into the machine with the following command:

gcloud alpha compute tpus tpu-vm ssh $GCP_TPU_NAME --zone $GCP_ZONE --project $GCP_PROJECT

The system is a classic Ubuntu 20.04 with some utility programs already available (e.g. tmux, vim). If you created a disk, the following commands show how to set it up inside the machine. The disk in the example is available on /dev/sdb, and we mount it on the data folder in the home directory:

# Check that your disk is visible and get its name
lsblk
# Mount disk in the data folder
sudo mkfs.ext4 -m 0 -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb
sudo mkdir -p data
sudo mount -o discard,defaults /dev/sdb data
sudo chmod a+w data
# Optionally, add the automatic mounting: https://cloud.google.com/compute/docs/disks/add-persistent-disk#configuring_automatic_mounting_on_vm_restart

# Permanently export the disk as the new path for HF dataset caching
# (absolute path, so the cache location does not depend on the current directory)
echo "export HF_DATASETS_CACHE=$HOME/data" >> ~/.bashrc
source ~/.bashrc

We then set up the environment with the required libraries and dependencies:

### Setup the TPUv3-8 environment
sudo apt update
sudo apt-get install python3.8-venv
python3 -m venv venv
source venv/bin/activate
pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install -r requirements.txt
sudo apt-get install git-lfs
git lfs install

The next step is to create a repository on the Huggingface Hub (here e.g. t5-base-it) that will contain all our files. We then train a tokenizer, select a config, and push the two JSON files in the repository we created.

Depending on the number of sentences used to train the tokenizer, this step may take quite some time (~2 hours for 10M sentences).

# Huggingface Hub variables
export HF_NAME="<YOUR_HF_HUB_USERNAME>"
export HF_PWD="<YOUR_HF_HUB_PASSWORD>"
export HF_TOKEN="<YOUR_HF_HUB_API_TOKEN>"
export HF_PROJECT="t5-base-it"

# Variables for training the tokenizer and creating the config
export VOCAB_SIZE="32000"
export N_INPUT_SENTENCES="1000000" # Num of sentences to train the tokenizer
export DATASET="gsarti/clean_mc4_it" # Name of the dataset in the Huggingface Hub
export DATASET_CONFIG="full" # Config of the dataset in the Huggingface Hub 
export DATASET_SPLIT="train" # Split to use for training tokenizer and model
export TEXT_FIELD="text" # Field containing the text to be used for training
export CONFIG_TYPE="google/t5-v1_1-base" # Config that our model will use
export MODEL_PATH="data/${HF_PROJECT}" # Path to the model, e.g. here inside the mount

### Setup the project on the HF Hub
#huggingface-cli login also works
# Make sure the token folder exists before writing to it
mkdir -p ~/.huggingface
echo $HF_TOKEN > ~/.huggingface/token
huggingface-cli repo create $HF_PROJECT
git clone "https://huggingface.co/${HF_NAME}/${HF_PROJECT}"
mv $HF_PROJECT $MODEL_PATH

# Create the tokenizer and the config
python create_tokenizer_cfg.py \
    --model_dir $MODEL_PATH \
    --dataset $DATASET \
    --dataset_config $DATASET_CONFIG \
    --dataset_split $DATASET_SPLIT \
    --text_field $TEXT_FIELD \
    --vocab_size $VOCAB_SIZE \
    --input_sentence_size $N_INPUT_SENTENCES \
    --config_type $CONFIG_TYPE

# Push the tokenizer and the config to the HF Hub
cd $MODEL_PATH
git lfs track "*tfevents*"
git add .
git commit -m "Added tokenizer and config"
# If this doesn't work, do a normal push and insert your credentials
git push "https://${HF_NAME}:${HF_PWD}@huggingface.co/${HF_NAME}/${HF_PROJECT}"
cd ..

The last step is to run the pretraining. The following command will train the model on the same dataset used to train the tokenizer, logging train/eval metrics to Tensorboard and pushing model checkpoints and artifacts to the Hub every save_steps steps:

Important: Since checkpoints are stored with Git LFS on the Hub, depending on the value of save_steps you may run out of disk space (e.g. training a model with 1GB checkpoints for 1M steps with save_steps=10_000 will need 100GB of free space to avoid crashes during training). Two ways to avoid problems are:

  • Use save_steps > (tot_steps / your_available_disk_gb) * ckpt_size_gb. E.g. with 1M steps, 1GB checkpoints and only 76GB free, save_steps must be higher than 1M / 76 ~= 13158. Failing to do so will interrupt the training once the disk is full.

  • Use git lfs prune --verify-remote to periodically remove old cached checkpoints from .git/lfs/objects and free up some space. If done often enough this will prevent any problem of space, but requires manual intervention of the user.
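
The first rule can be checked with a few lines of arithmetic before launching a run. The sketch below is only an illustration: the function names min_save_steps and disk_needed_gb are hypothetical and not part of this repository, and it assumes every checkpoint has the same size and stays cached locally.

```python
import math


def min_save_steps(total_steps, ckpt_size_gb, free_disk_gb):
    """Smallest integer save_steps that is strictly greater than
    (total_steps / free_disk_gb) * ckpt_size_gb."""
    return math.floor(total_steps * ckpt_size_gb / free_disk_gb) + 1


def disk_needed_gb(total_steps, save_steps, ckpt_size_gb):
    """Disk space taken by all checkpoints if none is pruned."""
    return (total_steps // save_steps) * ckpt_size_gb


# Example from the text: 1M steps, 1GB checkpoints, 76GB free
print(min_save_steps(1_000_000, 1.0, 76))
# Example from the text: save_steps=10_000 over 1M steps
print(disk_needed_gb(1_000_000, 10_000, 1.0))
```

If the second number is larger than your free space, either raise save_steps or plan for periodic pruning.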

# Run the training
# Adjust the parameters to your needs
# To avoid using an extra disk for storing the model, 
python run_t5_mlm_flax.py \
    --output_dir=$MODEL_PATH \
    --model_type="t5" \
    --config_name=$MODEL_PATH \
    --tokenizer_name=$MODEL_PATH \
    --preprocessing_num_workers="96" \
    --do_train --do_eval \
    --adafactor \
    --dataset_name="gsarti/clean_mc4_it" \
    --dataset_config_name="full" \
    --max_seq_length="512" \
    --per_device_train_batch_size="8" \
    --per_device_eval_batch_size="8" \
    --learning_rate="0.005" \
    --overwrite_output_dir \
    --num_train_epochs="1" \
    --logging_steps="500" \
    --save_steps="80000" \
    --eval_steps="2500" \
    --weight_decay="0.01" \
    --warmup_steps="10000" \
    --validation_split_count="15000" \
    --push_to_hub \
    #--gradient_accumulation_steps="2" # Uncomment to use gradient accumulation
    #--resume_from_checkpoint=$MODEL_PATH # Uncomment to resume from ckpt

Exporting Checkpoints

Your model is now trained, congratulations! 🎉 Now you may want to export your final Flax model to other frameworks. This can be done easily:

python export_checkpoint.py --model_dir $MODEL_PATH
cd $MODEL_PATH
git add .
git commit -m "Added TF and PT models"
git push
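
The contents of export_checkpoint.py are not reproduced here. As a rough sketch of what a Flax-to-PyTorch export could look like, the snippet below uses the from_flax=True loading option mentioned in the overview. The function name flax_to_pytorch and both directory arguments are hypothetical, and it assumes that transformers, flax and torch are all installed.

```python
def flax_to_pytorch(flax_ckpt_dir, pytorch_out_dir):
    """Load a Flax T5 checkpoint as a PyTorch model and save it in PyTorch format."""
    # Imported lazily so that defining the function does not require
    # the heavy dependencies to be importable.
    from transformers import T5ForConditionalGeneration

    # from_flax=True tells transformers to read the Flax weights in the folder
    model = T5ForConditionalGeneration.from_pretrained(flax_ckpt_dir, from_flax=True)
    # Writes the PyTorch weights and the config to the output folder
    model.save_pretrained(pytorch_out_dir)
```

Calling flax_to_pytorch with the model folder as both arguments would place the PyTorch weights next to the Flax ones, ready to be committed.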

Fine-tune a pretrained T5 model in Flax

TODO

Useful Tips

About GCP Billing: After TPU costs, which can be offset by taking part in the TRC program, the second-highest billing item is the upload of files from the VM to another destination. This happens every save_steps steps during training, to keep the Tensorboard logs and the checkpoints on the HF Hub in sync. For in-region uploads, at the time of writing the cost is up to 0.12 USD per GB. This means that for a full pretraining of 1M steps with a checkpoint saved every 10k steps, with checkpoints of roughly 1GB each, the cost is roughly 12 USD (100 uploads x 1GB x 0.12 USD/GB). To reduce this cost, simply increase save_steps when launching the run_t5_mlm_flax.py script.
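
The estimate above can be reproduced for other settings with a small helper. This is only a sketch: upload_cost_usd is a hypothetical name, the 0.12 USD/GB price is the upper bound quoted above rather than a current rate, and Tensorboard log uploads are ignored.

```python
def upload_cost_usd(total_steps, save_steps, ckpt_size_gb, usd_per_gb):
    """Approximate egress cost of pushing one checkpoint every save_steps steps."""
    n_uploads = total_steps // save_steps
    return n_uploads * ckpt_size_gb * usd_per_gb


# Example from the text: 1M steps, a 1GB checkpoint every 10k steps
print(upload_cost_usd(1_000_000, 10_000, 1.0, 0.12))
# Same run with save_steps=80000, the value used in the training command
print(upload_cost_usd(1_000_000, 80_000, 1.0, 0.12))
```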

About data preprocessing: Imagine you've trained a base model and now you want to train a large variant, using the same tokenizer. Even if the tokenizer is the same, if the path of the tokenizer changes the cache will be invalidated and everything will be recomputed. This is not a tragedy: it only requires some extra preprocessing time. Keep in mind, though, that tokenization + grouping take substantial space on disk, so if there is not enough space you will need to delete the Arrow files from the previous HuggingFace dataset cache. You can locate them with something like:

find . -type f -size +2G -size -4G -exec ls -lah {} + | grep 'Aug 25 13' | wc -l

As an example, knowing that my tokenization completed on Aug 25 at 13:xx and that the files in the cache are 3GB each, this should return the value you set for n_proc. You can get the filenames of the old preprocessing caches by removing the wc -l pipe, and then delete them all to make room for the new ones.
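
The same search can be written in Python if you prefer to inspect the matches before deleting anything. The sketch below is an assumption-laden illustration: find_cache_files is a hypothetical helper, and the size bounds, date and hour are values you have to supply for your own run.

```python
from datetime import datetime
from pathlib import Path


def find_cache_files(root, min_bytes, max_bytes, day, hour):
    """Return files under root whose size is within [min_bytes, max_bytes]
    and whose last modification happened on the given date and hour."""
    matches = []
    for path in Path(root).rglob("*"):
        if not path.is_file():
            continue
        stat = path.stat()
        # Modification time in local time, like the timestamps printed by ls
        modified = datetime.fromtimestamp(stat.st_mtime)
        in_size_range = min_bytes <= stat.st_size <= max_bytes
        if in_size_range and modified.date() == day and modified.hour == hour:
            matches.append(path)
    return sorted(matches)
```

The function only lists candidates. Once the list matches what you expect (for instance, as many files as n_proc), the files can be removed with Path.unlink().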

Owner
Gabriele Sarti
PhD Student in NLP @ University of Groningen | Prev: Aindo, ILC-CNR Pisa, Skytech Montreal