From a57d4585a829238b1a4fb45c88333fb7d1688d0f Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Thu, 2 May 2024 10:10:07 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 create_seed_checkpoint.sh | 23 +++++++++++++----------
 train.py                  |  2 +-
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh
index 0c00145d6..38bab219f 100755
--- a/create_seed_checkpoint.sh
+++ b/create_seed_checkpoint.sh
@@ -5,24 +5,27 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+#
+# create_seed_checkpoint.sh
+#
+# Run this script to create a seed checkpoint used to initialize a model from step-0.
+# Seed checkpoints are used to initialize pipeline-parallel models, since the model initializer
+# functions don't cleanly run on chunked model parts after meta-initialization.
+#
+# Use the same model config to generate your seed checkpoint as you use for training.
+# e.g.
+# CONFIG_FILE=<config> ./create_seed_checkpoint.sh
+
 set -ex
 
-# libUV is a scalable backend for TCPStore which is used in processGroup
-# rendezvous. This is the recommended backend for distributed training.
 export USE_LIBUV=1
 TRAINER_DIR=${1:-/home/$USER/local/torchtitan}
-
-# use envs as local overrides for convenience
-# e.g.
-# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
-
 NGPU=1
 LOG_RANK=0
-
-
 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
 
 seed_checkpoint="--checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint"
+force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --training.pipeline_parallel_degree 1"
 overrides=""
 if [ $# -ne 0 ]; then
     overrides="$*"
@@ -30,4 +33,4 @@ fi
 
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --job.config_file ${CONFIG_FILE} $seed_checkpoint $overrides
+train.py --job.config_file ${CONFIG_FILE} $seed_checkpoint $force_1d $overrides

diff --git a/train.py b/train.py
index 553e79e63..b47cd0f2c 100644
--- a/train.py
+++ b/train.py
@@ -405,10 +405,10 @@ def loss_fn(pred, labels):
 
     metric_logger.close()
     logger.info("Training completed")
-    destroy_process_group()
 
 
 if __name__ == "__main__":
     config = JobConfig()
     config.parse_args()
     main(config)
+    destroy_process_group()
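
Note (illustration, not part of the patch): the comment block added above says seed checkpoints exist because model initializer functions don't run cleanly on chunked model parts after meta-initialization. Below is a minimal Python sketch of that workflow, assuming a recent PyTorch where torch.distributed.checkpoint exposes save/load with a checkpoint_id argument; ToyModel and the "seed_ckpt" path are hypothetical stand-ins, not torchtitan code.

import torch
import torch.nn as nn
import torch.distributed.checkpoint as dcp

# Hypothetical stand-in for the real model; not torchtitan code.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))

    def init_weights(self):
        # Initializers like this walk the full module tree, which is why
        # they don't run cleanly on a pipeline chunk with layers missing.
        for layer in self.layers:
            nn.init.trunc_normal_(layer.weight, std=0.02)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Step 1 (what this script automates): initialize the whole model once on a
# single rank and save the step-0 weights.
model = ToyModel()
model.init_weights()
dcp.save({"model": model.state_dict()}, checkpoint_id="seed_ckpt")

# Step 2 (pipeline-parallel training): each rank builds its part on the meta
# device, materializes empty storage, and fills it from the seed checkpoint
# instead of re-running init_weights on a partial model.
with torch.device("meta"):
    stage = ToyModel()  # in practice, one chunk of the model
stage.to_empty(device="cpu")
state = {"model": stage.state_dict()}
dcp.load(state, checkpoint_id="seed_ckpt")  # copies in place into the module's tensors

This also motivates the $force_1d overrides: with all parallel degrees pinned to 1 and NGPU=1, the seed checkpoint is written as full, unsharded weights, which distributed checkpointing can then reshard into whatever parallel layout the real training run uses.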