We use the official JAX/FLAX example in Hugging Face Transformers to pretrain a ByT5 model on a single v3-8 TPU.
The following steps are adapted from the TPU VM Cheatsheet and were tested with these library versions:
| Library | Version |
| --- | --- |
| JAX | 0.3.25 |
| FLAX | 0.6.4 |
| Datasets | 2.10.1 |
| Transformers | 4.27.1 |
| Chex | 0.1.6 |
Please note that it could also work with later versions - but it's not guaranteed ;)
First, create a persistent disk that will later hold the training corpora, the Hugging Face datasets cache and the model output directories:
gcloud compute disks create lms --zone us-central1-a --size 1024G
Make sure that your disk is in the same zone as your TPU VM!
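If in doubt, the disk and its zone can e.g. be listed with:
gcloud compute disks list --filter="name=lms" --format="value(name,zone)"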
The following command creates a v3-8 TPU VM and attaches the previously created disk to it:
gcloud alpha compute tpus tpu-vm create hmbyt5 --zone us-central1-a \
--accelerator-type v3-8 \
--version tpu-vm-base \
--data-disk source=projects/<project-name>/zones/us-central1-a/disks/lms
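Before connecting, you can optionally check that the TPU VM has reached the READY state:
gcloud alpha compute tpus tpu-vm describe hmbyt5 --zone us-central1-a --format="value(state)"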
Just run the following command to SSH into the TPU VM:
gcloud alpha compute tpus tpu-vm ssh hmbyt5 --zone us-central1-a
After SSH'ing into the TPU VM, run the following commands, e.g. in a tmux session:
sudo apt update -y && sudo apt install -y python3-venv
python3 -m venv $HOME/dev
source $HOME/dev/bin/activate
pip install "jax[tpu]==0.3.25" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install ipython requests
git clone https://github.com/huggingface/transformers.git
git clone https://github.com/huggingface/datasets.git
git clone https://github.com/google/flax.git
cd transformers && git checkout v4.27.1 && pip3 install -e . && cd ..
cd datasets && git checkout 2.10.1 && pip3 install -e . && cd ..
cd flax && git checkout v0.6.4 && pip3 install -e . && cd ..
pip install chex==0.1.6
# Useful symlinks
ln -s $HOME/transformers/examples/flax/language-modeling/run_t5_mlm_flax.py run_t5_mlm_flax.py
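As a quick sanity check that the TPU runtime is set up correctly, JAX should report all 8 cores of the v3-8:
python3 -c "import jax; print(jax.device_count()); print(jax.devices())"
This should print 8, followed by a list of TPU devices.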
The attached disk needs to be formatted first:
sudo mkfs.ext4 -m 0 -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb
After that, it can be mounted and made writable:
sudo mkdir -p /mnt/datasets
sudo mount -o discard,defaults /dev/sdb /mnt/datasets/
sudo chmod a+w /mnt/datasets
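Whether the disk is mounted correctly can be checked with:
df -h /mnt/datasets
lsblk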
The Hugging Face Datasets cache environment variable should now point to the mounted disk:
export HF_DATASETS_CACHE=/mnt/datasets/huggingface
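Note that this only affects the current shell session. To make the setting persistent, e.g. across new tmux sessions, it can additionally be appended to ~/.bashrc:
echo 'export HF_DATASETS_CACHE=/mnt/datasets/huggingface' >> ~/.bashrc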
The following commands create and activate a 10GB swapfile on the mounted disk:
cd /mnt/datasets
sudo fallocate -l 10GB ./swapfile
sudo chmod 600 ./swapfile
sudo mkswap ./swapfile
sudo swapon ./swapfile
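The active swap space can then be verified with:
swapon --show
free -h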
Install TensorBoard to get better training metric visualizations:
pip install tensorboard==2.12.1 tensorflow==2.12
In order to push all checkpoints directly to the Model Hub, we need to set up Git LFS first:
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
sudo apt install -y git-lfs
git lfs install
git config --global credential.helper store
After that, Model Hub credentials need to be stored:
huggingface-cli login
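A successful login can be verified with:
huggingface-cli whoami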
We use available training splits from NER corpora to construct a validation dataset.
The create_validation_data.py script can be used to create validation splits for all languages.
In our first experiment, we train for one epoch over the English (blbooks) corpus, using the following command:
python run_t5_mlm_flax.py \
--model_name_or_path="google/byt5-small" \
--output_dir="/mnt/datasets/byt5-small-english" \
--max_seq_length="1024" \
--per_device_train_batch_size="16" \
--per_device_eval_batch_size="16" \
--learning_rate="0.0003" \
--warmup_steps="10000" \
--logging_steps="500" \
--save_steps="10000" \
--eval_steps="2500" \
--train_file="/mnt/datasets/corpora/english.txt" \
--validation_file="/mnt/datasets/validation/en_validation.txt" \
--hub_model_id="stefan-it/byt5-small-english" \
--num_train_epochs="1" \
--preprocessing_num_workers="16" \
--push_to_hub
Checkpoints and the TensorBoard logs are automatically uploaded to the Model Hub and can be found here.
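Training can also be monitored directly on the TPU VM: the pretraining script writes TensorBoard event files into the output directory, so TensorBoard can e.g. be started there and accessed via SSH port forwarding (port 6006 is just an example):
# On the TPU VM:
tensorboard --logdir /mnt/datasets/byt5-small-english --port 6006
# On the local machine:
gcloud alpha compute tpus tpu-vm ssh hmbyt5 --zone us-central1-a -- -L 6006:localhost:6006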
In the second experiment, we use the previously pretrained English model as the initial checkpoint and continue pretraining on the German corpus, using the last learning rate of the English run and zero warmup steps:
python run_t5_mlm_flax.py \
--model_name_or_path="/mnt/datasets/byt5-small-english" \
--output_dir="/mnt/datasets/byt5-small-english-german" \
--max_seq_length="1024" \
--per_device_train_batch_size="16" \
--per_device_eval_batch_size="16" \
--learning_rate="4.955113013238588e-07" \
--warmup_steps="0" \
--logging_steps="500" \
--save_steps="10000" \
--eval_steps="2500" \
--train_file="/mnt/datasets/corpora/german.txt" \
--validation_file="/mnt/datasets/validation/de_validation.txt" \
--hub_model_id="stefan-it/byt5-small-english-german" \
--num_train_epochs="1" \
--preprocessing_num_workers="16" \
--push_to_hub
In another experiment, we sample 4B bytes (~4GB of text) from each corpus (and upsample Swedish and Finnish).
We extend the JAX/FLAX pretraining script so that it is possible to perform evaluation on multiple validation datasets.
This yields a detailed accuracy and loss curve for each language-specific validation dataset.
The modified script can be found under run_t5_mlm_flax.py.
The pretraining was started for one epoch with:
python run_t5_mlm_flax.py \
--model_name_or_path="google/byt5-small" \
--output_dir="/mnt/datasets/byt5-small-multilingual-4g" \
--max_seq_length="1024" \
--per_device_train_batch_size="16" \
--per_device_eval_batch_size="16" \
--learning_rate="0.0003" \
--warmup_steps="10000" \
--logging_steps="500" \
--save_steps="10000" \
--eval_steps="2500" \
--train_file="/mnt/datasets/corpus/combined.txt" \
--validation_file="en:/mnt/datasets/validation/en_validation.txt,de:/mnt/datasets/validation/de_validation.txt,fr:/mnt/datasets/validation/fr_validation.txt,fi:/mnt/datasets/validation/fi_validation.txt,sv:/mnt/datasets/validation/sv_validation.txt,nl:/mnt/datasets/validation/nl_validation.txt" \
--hub_model_id="stefan-it/byt5-small-multilingual-4g" \
--num_train_epochs="1" \
--preprocessing_num_workers="16" \
--push_to_hub