hmByT5 JAX/FLAX pretraining

We use the official JAX/FLAX example in Hugging Face Transformers to pretrain a ByT5 model on a single v3-8 TPU.

The following steps are adapted from the TPU VM Cheatsheet.

TPU VM Setup

Library Version
JAX 0.3.25
FLAX 0.6.4
Datasets 2.10.1
Transformers 4.27.1
Chex 0.1.6

Please note that later versions might also work, but this is not guaranteed ;)

Create disk with additional storage

gcloud compute disks create lms --zone us-central1-a --size 1024G

Make sure that your disk is in the same zone as your TPU VM!

Create v3-8 TPU VM

The following command creates a v3-8 TPU VM and attaches the previously created disk to it:

gcloud alpha compute tpus tpu-vm create hmbyt5 --zone us-central1-a \
--accelerator-type v3-8 \
--version tpu-vm-base \
--data-disk source=projects/<project-name>/zones/us-central1-a/disks/lms
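Because the disk and the TPU VM must share a zone, one option is to define the zone (and the resource names) once and reuse them in both commands. The sketch below does that; the variable names ZONE, DISK_NAME, TPU_NAME and PROJECT, the helper functions run and create_disk_and_tpu, and the DRY_RUN switch are all hypothetical additions, and the gcloud arguments are the ones from the commands above:

```shell
#!/usr/bin/env bash
# Sketch: define the zone once so the data disk and the TPU VM cannot end up
# in different zones. All variable and function names here are hypothetical.
ZONE="${ZONE:-us-central1-a}"
DISK_NAME="${DISK_NAME:-lms}"
TPU_NAME="${TPU_NAME:-hmbyt5}"
PROJECT="${PROJECT:-<project-name>}"

# Run a command, or only print it when DRY_RUN=1.
run() {
  if [ "${DRY_RUN:-0}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Create the 1 TB data disk and a v3-8 TPU VM that attaches it, both in $ZONE.
create_disk_and_tpu() {
  run gcloud compute disks create "$DISK_NAME" --zone "$ZONE" --size 1024G
  run gcloud alpha compute tpus tpu-vm create "$TPU_NAME" --zone "$ZONE" \
    --accelerator-type v3-8 \
    --version tpu-vm-base \
    --data-disk "source=projects/${PROJECT}/zones/${ZONE}/disks/${DISK_NAME}"
}
```

After sourcing the file, DRY_RUN=1 create_disk_and_tpu prints the two commands for review, and create_disk_and_tpu runs them.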


Just run the following command to SSH into the TPU VM:

gcloud alpha compute tpus tpu-vm ssh hmbyt5 --zone us-central1-a 

Installation of libraries

After SSH'ing into the TPU VM, run the following commands (e.g. inside tmux):

sudo apt update -y && sudo apt install -y python3-venv
python3 -m venv $HOME/dev
source $HOME/dev/bin/activate
pip install "jax[tpu]==0.3.25" -f
pip install ipython requests
git clone
git clone
git clone
cd transformers && git checkout v4.27.1 && pip3 install -e . && cd ..
cd datasets && git checkout 2.10.1 && pip3 install -e . && cd ..
cd flax && git checkout v0.6.4 && pip3 install -e . && cd ..
pip install chex==0.1.6

# Useful symlinks
ln -s $HOME/transformers/examples/flax/language-modeling/

Format and mount disk

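The attached disk needs to be formatted first. This erases all data on it, so only do it once, and check with lsblk that the disk is really attached as /dev/sdb: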

sudo mkfs.ext4 -m 0 -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb

After that it can be mounted via:

sudo mkdir -p /mnt/datasets
sudo mount -o discard,defaults /dev/sdb /mnt/datasets/
sudo chmod a+w /mnt/datasets
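The mount above does not survive a reboot of the TPU VM. A minimal sketch for persisting it, assuming the disk keeps the ext4 filesystem created above, is an /etc/fstab entry keyed by the filesystem UUID, since device names such as /dev/sdb are not guaranteed to be stable. fstab_entry is a hypothetical helper; the nofail option lets the VM boot even if the disk is not attached:

```shell
#!/usr/bin/env bash
# Hypothetical helper: print an /etc/fstab line that mounts the ext4 data disk
# (identified by its filesystem UUID) at the given mount point on boot.
# "nofail" keeps the boot going if the disk is missing.
fstab_entry() {
  local uuid="$1" mountpoint="${2:-/mnt/datasets}"
  printf 'UUID=%s %s ext4 discard,defaults,nofail 0 2\n' "$uuid" "$mountpoint"
}
```

Usage on the TPU VM: uuid="$(sudo blkid -s UUID -o value /dev/sdb)", then fstab_entry "$uuid" | sudo tee -a /etc/fstab.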

HF Datasets Cache

The Hugging Face Datasets cache (HF_DATASETS_CACHE) should now point to the mounted disk:

export HF_DATASETS_CACHE=/mnt/datasets/huggingface
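The export only applies to the current shell. One way to make it persistent, sketched here with a hypothetical helper add_line_once, is to append it to ~/.bashrc exactly once, so new tmux panes and SSH sessions also use the cache on the mounted disk:

```shell
#!/usr/bin/env bash
# Hypothetical helper: append LINE to FILE unless FILE already contains
# exactly that line, so running it repeatedly does not create duplicates.
add_line_once() {
  local file="$1" line="$2"
  touch "$file"                         # make sure the file exists for grep
  grep -qxF -- "$line" "$file" || printf '%s\n' "$line" >> "$file"
}
```

For example: add_line_once "$HOME/.bashrc" 'export HF_DATASETS_CACHE=/mnt/datasets/huggingface'.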

Create swapfile

The following commands create and activate a swapfile:

cd /mnt/datasets
sudo fallocate -l 10GB ./swapfile
sudo chmod 600 ./swapfile
sudo mkswap ./swapfile
sudo swapon ./swapfile
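The swapfile is not re-enabled automatically after a reboot, so it can be worth checking that it is active before a long preprocessing run. swapon --show lists the active swap areas interactively; the hypothetical helper below performs the same check in a script by reading /proc/swaps (the second argument only exists to make the check testable):

```shell
#!/usr/bin/env bash
# Hypothetical helper: succeed if SWAPFILE appears as an active swap area in
# /proc/swaps (or in the file passed as the second argument).
swap_is_active() {
  local swapfile="$1" procfile="${2:-/proc/swaps}"
  # Skip the header line, compare the first column (the swap file name).
  awk -v f="$swapfile" 'NR > 1 && $1 == f { found = 1 } END { exit !found }' "$procfile"
}
```

For example, swap_is_active /mnt/datasets/swapfile || sudo swapon /mnt/datasets/swapfile re-enables the swapfile after the disk has been mounted again.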


Install TensorBoard to visualize the training metrics:

pip install tensorboard==2.12.1 tensorflow==2.12

Hugging Face Model Hub Login

In order to push all checkpoints directly to the Model Hub, we need to set up Git LFS first:

curl -s | sudo bash
sudo apt install -y git-lfs
git lfs install
git config --global credential.helper store

After that, Model Hub credentials need to be stored:

huggingface-cli login

Validation Data

We use the available training splits of NER corpora to construct the validation data. The … can be used to create validation splits for all languages.
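Turning an NER training split into plain validation text mostly means dropping the labels and joining tokens back into sentences. A minimal sketch, assuming a CoNLL-style column format (token in the first column, sentences separated by blank lines, metadata lines starting with "#"); conll_to_text is a hypothetical helper, and corpora with a different layout, such as a column header line, would need adjustments:

```shell
#!/usr/bin/env bash
# Hypothetical helper: convert CoNLL-style NER files into plain text with one
# sentence per line. Assumes: token in column 1, blank line between sentences,
# lines starting with "#" are comments/metadata.
conll_to_text() {
  awk '
    /^#/ { next }                                 # skip comment lines
    NF == 0 {                                     # blank line ends a sentence
      if (sent != "") print sent
      sent = ""
      next
    }
    { sent = (sent == "" ? $1 : sent " " $1) }    # collect the token column
    END { if (sent != "") print sent }            # flush the last sentence
  ' "$@"
}
```

For example: conll_to_text en_train.tsv > /mnt/datasets/validation/en_validation.txt, where en_train.tsv is a placeholder for an English NER training split.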



In our first experiment, we train for one epoch over the English (blbooks) corpus, using the following command:

python \
--model_name_or_path="google/byt5-small" \
--output_dir="/mnt/datasets/byt5-small-english" \
--max_seq_length="1024" \
--per_device_train_batch_size="16" \
--per_device_eval_batch_size="16" \
--learning_rate="0.0003" \
--warmup_steps="10000" \
--logging_steps="500" \
--save_steps="10000" \
--eval_steps="2500" \
--train_file="/mnt/datasets/corpora/english.txt" \
--validation_file="/mnt/datasets/validation/en_validation.txt" \
--hub_model_id="stefan-it/byt5-small-english" \
--num_train_epochs="1" \
--preprocessing_num_workers="16" \
--push_to_hub

Checkpoints and the TensorBoard logs are automatically uploaded to the Model Hub, and can be found here.
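To relate --per_device_train_batch_size and --warmup_steps to the length of the run, it helps to estimate the number of optimizer steps. As far as we know, the Flax T5 MLM example uses a global batch of per_device_train_batch_size × jax.device_count() (8 TPU cores on a v3-8, so 16 × 8 = 128 sequences per step) and floor(num_examples / global_batch) × num_epochs training steps, where num_examples is the number of grouped training sequences after preprocessing. The sketch below reproduces that arithmetic; train_steps is a hypothetical helper:

```shell
#!/usr/bin/env bash
# Hypothetical helper: number of optimizer steps for the given number of
# grouped training sequences, per-device batch size, device count
# (default 8 = cores of a v3-8) and number of epochs (default 1).
train_steps() {
  local num_examples="$1" per_device="$2" num_devices="${3:-8}" epochs="${4:-1}"
  local global_batch=$(( per_device * num_devices ))
  # Integer division drops the last incomplete batch of each epoch.
  echo $(( num_examples / global_batch * epochs ))
}
```

With an illustrative 1,280,000 training sequences, train_steps 1280000 16 prints 10000, i.e. the whole run would consist of warmup with --warmup_steps="10000". The warmup should therefore stay well below the estimated total.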


In the second experiment, we use the previously pretrained English model as the initial checkpoint and continue pretraining on the German corpus, using the last learning rate of the English run and zero warmup steps:

python \
--model_name_or_path="/mnt/datasets/byt5-small-english" \
--output_dir="/mnt/datasets/byt5-small-english-german" \
--max_seq_length="1024" \
--per_device_train_batch_size="16" \
--per_device_eval_batch_size="16" \
--learning_rate="4.955113013238588e-07" \
--warmup_steps="0" \
--logging_steps="500" \
--save_steps="10000" \
--eval_steps="2500" \
--train_file="/mnt/datasets/corpora/german.txt" \
--validation_file="/mnt/datasets/validation/de_validation.txt" \
--hub_model_id="stefan-it/byt5-small-english-german" \
--num_train_epochs="1" \
--preprocessing_num_workers="16" \
--push_to_hub
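The value 4.955113013238588e-07 passed above is the learning rate at which the English run ended. Assuming the unmodified schedule of the Flax T5 MLM example (linear warmup from 0 to --learning_rate over --warmup_steps, then linear decay to 0 at the last step, built with optax), the learning rate at any step can be computed with the hypothetical helper below. Under the same assumption, --warmup_steps="0" means the German run starts at 4.955113013238588e-07 and decays further towards 0 rather than staying constant:

```shell
#!/usr/bin/env bash
# Hypothetical helper: learning rate for linear warmup + linear decay.
# Arguments: peak learning rate, warmup steps, total steps, current step.
lr_at_step() {
  awk -v p="$1" -v w="$2" -v t="$3" -v s="$4" 'BEGIN {
    p += 0; w += 0; t += 0; s += 0                    # force numeric values
    if (s < w)      lr = p * s / w                    # warmup: 0 -> peak
    else if (t > w) lr = p * (1 - (s - w) / (t - w))  # decay: peak -> 0
    else            lr = 0
    if (lr < 0) lr = 0                                # stays at 0 after the end
    printf "%.6g\n", lr
  }'
}
```

For example, with a hypothetical total of 110000 steps, lr_at_step 0.0003 10000 110000 60000 prints 0.00015, halfway through the decay phase.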

Pretraining - Multilingual model on subcorpus

In another experiment, we sample 4B bytes (~4GB of text) from each corpus (and upsample Swedish and Finnish). We extend the JAX/FLAX pretraining script so that it is possible to evaluate on multiple validation datasets. This gives separate accuracy and loss curves for the validation dataset of each language. The modified script can be found under
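The per-language sampling can be sketched as follows, assuming plain-text corpora with one text per line: take complete lines from the start of a corpus until a byte budget is reached, and repeat the sample to upsample a smaller language. The helper name sample_corpus, taking a prefix instead of a random sample, and any repeat factor are assumptions, since the exact sampling procedure is not described here:

```shell
#!/usr/bin/env bash
# Hypothetical helper: print complete lines from the start of INPUT until
# MAX_BYTES would be exceeded, REPEATS times (default 1) for upsampling.
sample_corpus() {
  local input="$1" max_bytes="$2" repeats="${3:-1}" i
  for (( i = 0; i < repeats; i++ )); do
    # LC_ALL=C makes length() count bytes instead of characters.
    LC_ALL=C awk -v max="$max_bytes" '
      { total += length($0) + 1 }   # bytes of this line plus its newline
      total > max { exit }          # stop before exceeding the budget
      { print }
    ' "$input"
  done
}
```

For example, { sample_corpus /mnt/datasets/corpora/english.txt 4000000000; sample_corpus swedish.txt 4000000000 2; } > /mnt/datasets/corpus/combined.txt, where swedish.txt and the factor 2 are placeholders.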

The pretraining was started for one epoch with:

python \
--model_name_or_path="google/byt5-small" \
--output_dir="/mnt/datasets/byt5-small-multilingual-4g" \
--max_seq_length="1024" \
--per_device_train_batch_size="16" \
--per_device_eval_batch_size="16" \
--learning_rate="0.0003" \
--warmup_steps="10000" \
--logging_steps="500" \
--save_steps="10000" \
--eval_steps="2500" \
--train_file="/mnt/datasets/corpus/combined.txt" \
--validation_file="en:/mnt/datasets/validation/en_validation.txt,de:/mnt/datasets/validation/de_validation.txt,fr:/mnt/datasets/validation/fr_validation.txt,fi:/mnt/datasets/validation/fi_validation.txt,sv:/mnt/datasets/validation/sv_validation.txt,nl:/mnt/datasets/validation/nl_validation.txt" \
--hub_model_id="stefan-it/byt5-small-multilingual-4g" \
--num_train_epochs="1" \
--preprocessing_num_workers="16" \
--push_to_hub
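The --validation_file value for the modified script is a comma-separated list of lang:path pairs. Instead of typing it by hand, it can be generated from a list of language codes. The helper below is a hypothetical sketch that assumes the <lang>_validation.txt naming used above:

```shell
#!/usr/bin/env bash
# Hypothetical helper: build "lang:DIR/lang_validation.txt,..." for the given
# validation directory and language codes.
validation_arg() {
  local dir="$1"; shift
  local lang out=""
  for lang in "$@"; do
    # ${out:+,} inserts a comma only between entries.
    out+="${out:+,}${lang}:${dir}/${lang}_validation.txt"
  done
  printf '%s\n' "$out"
}
```

For example, --validation_file="$(validation_arg /mnt/datasets/validation en de fr fi sv nl)" reproduces the value used in the command above.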