From 5d2f03b127201af9b0a8414641551a6ac7eb611d Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Wed, 20 Sep 2023 22:01:37 +0800 Subject: [PATCH 1/3] [wenet] use torchrun for distributed training --- examples/aishell/s0/run.sh | 33 ++++--------------- wenet/bin/train.py | 65 ++++++++++++-------------------------- 2 files changed, 28 insertions(+), 70 deletions(-) diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index d3ff2ddfa..eb7e68b4a 100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -6,23 +6,16 @@ # Use this to control how many gpu you use, It's 1-gpu training if you specify # just 1gpu, otherwise it's is multiple gpu training based on DDP in pytorch export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" -# The NCCL_SOCKET_IFNAME variable specifies which IP interface to use for nccl -# communication. More details can be found in -# https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html -# export NCCL_SOCKET_IFNAME=ens4f1 -export NCCL_DEBUG=INFO + stage=0 # start from 0 if you need to start from data preparation stop_stage=5 -# The num of machines(nodes) for multi-machine training, 1 is for one machine. -# NFS is required if num_nodes > 1. +# The aishell dataset location, please change this to your own path +# You should change the following two parameters for multiple machine training, +# see https://pytorch.org/docs/stable/elastic/run.html +HOST_NODE_ADDR="localhost:0" num_nodes=1 -# The rank of each node or machine, which ranges from 0 to `num_nodes - 1`. -# You should set the node_rank=0 on the first machine, set the node_rank=1 -# on the second machine, and so on. -node_rank=0 -# The aishell dataset location, please change this to your own path # make sure of using absolute path. DO-NOT-USE relatvie path! data=/export/data/asr-data/OpenSLR/33/ data_url=www.openslr.org/resources/33 @@ -128,8 +121,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') # Use "nccl" if it works, otherwise use "gloo" dist_backend="nccl" - world_size=`expr $num_gpus \* $num_nodes` - echo "total gpus is: $world_size" cmvn_opts= $cmvn && cp data/${train_set}/global_cmvn $dir $cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn" @@ -165,13 +156,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --pin_memory else echo "using torch ddp" - for ((i = 0; i < $num_gpus; ++i)); do - { - gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) - # Rank of each gpu/process used for knowing whether it is - # the master of a worker. 
- rank=`expr $node_rank \* $num_gpus + $i` - python wenet/bin/train.py --gpu $gpu_id \ + torchrun --standalone --nnodes=$num_nodes --nproc_per_node=$num_gpus --rdzv_endpoint=$HOST_NODE_ADDR \ + wenet/bin/train.py \ --config $train_config \ --data_type $data_type \ --symbol_table $dict \ @@ -180,16 +166,11 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then ${checkpoint:+--checkpoint $checkpoint} \ --model_dir $dir \ --ddp.init_method $init_method \ - --ddp.world_size $world_size \ - --ddp.rank $rank \ --ddp.dist_backend $dist_backend \ --num_workers ${num_workers} \ --prefetch ${prefetch} \ $cmvn_opts \ --pin_memory - } & - done - wait fi fi diff --git a/wenet/bin/train.py b/wenet/bin/train.py index da9a6f6bb..4c0397a1d 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -51,26 +51,11 @@ def get_args(): help='train and cv data type') parser.add_argument('--train_data', required=True, help='train data file') parser.add_argument('--cv_data', required=True, help='cv data file') - parser.add_argument('--gpu', - type=int, - default=-1, - help='gpu id for this local rank, -1 for cpu') parser.add_argument('--model_dir', required=True, help='save model dir') parser.add_argument('--checkpoint', help='checkpoint model') parser.add_argument('--tensorboard_dir', default='tensorboard', help='tensorboard log dir') - parser.add_argument('--ddp.rank', - dest='rank', - default=0, - type=int, - help='global rank for distributed training') - parser.add_argument('--ddp.world_size', - dest='world_size', - default=-1, - type=int, - help='''number of total processes/gpus for - distributed training''') parser.add_argument('--ddp.dist_backend', dest='dist_backend', default='nccl', @@ -149,9 +134,6 @@ def main(): args = get_args() logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') - # NOTE(xcsong): deepspeed set CUDA_VISIBLE_DEVICES internally - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) if not args.deepspeed \ - else os.environ['CUDA_VISIBLE_DEVICES'] # Set random seed torch.manual_seed(777) @@ -169,27 +151,22 @@ def main(): else: configs["ds_dtype"] = "fp32" - # deepspeed read world_size from env - if args.deepspeed: - assert args.world_size == -1 - # distributed means pytorch native ddp, it parse world_size from args - distributed = args.world_size > 1 - local_rank = args.rank - world_size = args.world_size + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + distributed = world_size > 1 if distributed: - logging.info('training on multiple gpus, this gpu {}'.format(args.gpu)) + logging.info('training on multiple gpus, this gpu {}'.format(local_rank)) + torch.cuda.set_device(local_rank) dist.init_process_group(args.dist_backend, init_method=args.init_method, world_size=world_size, - rank=local_rank) + rank=rank) elif args.deepspeed: - # Update local_rank & world_size from enviroment variables - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) deepspeed.init_distributed(dist_backend=args.dist_backend, init_method=args.init_method, - rank=local_rank, - world_size=world_size) + world_size=world_size, + rank=rank) symbol_table = read_symbol_table(args.symbol_table) @@ -264,7 +241,7 @@ def main(): configs['is_json_cmvn'] = True configs['lfmmi_dir'] = args.lfmmi_dir - if local_rank == 0: + if rank == 0: saved_config_path = os.path.join(args.model_dir, 'train.yaml') with open(saved_config_path, 'w') as fout: data = 
yaml.dump(configs) @@ -279,7 +256,7 @@ def main(): # !!!IMPORTANT!!! # Try to export the model by script, if fails, we should refine # the code to satisfy the script export requirements - if local_rank == 0: + if rank == 0: script_model = torch.jit.script(model) script_model.save(os.path.join(args.model_dir, 'init.zip')) executor = Executor() @@ -298,7 +275,7 @@ def main(): num_epochs = configs.get('max_epoch', 100) model_dir = args.model_dir writer = None - if local_rank == 0: + if rank == 0: os.makedirs(model_dir, exist_ok=True) exp_id = os.path.basename(model_dir) writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id)) @@ -320,7 +297,7 @@ def main(): elif args.deepspeed: # deepspeed # NOTE(xcsong): look in detail how the memory estimator API works: # https://deepspeed.readthedocs.io/en/latest/memory.html#discussion - if local_rank == 0: + if rank == 0: logging.info("Estimating model states memory needs (zero2)...") estimate_zero2_model_states_mem_needs_all_live( model, num_gpus_per_node=world_size, num_nodes=1) @@ -330,7 +307,7 @@ def main(): device = None # Init device later pass # Init DeepSpeed later else: - use_cuda = args.gpu >= 0 and torch.cuda.is_available() + use_cuda = torch.cuda.is_available() device = torch.device('cuda' if use_cuda else 'cpu') model = model.to(device) @@ -370,7 +347,7 @@ def scheduler(opt): lr_scheduler=scheduler, model_parameters=model.parameters()) final_epoch = None - configs['rank'] = local_rank + configs['rank'] = rank configs['is_distributed'] = distributed # pytorch native ddp configs['is_deepspeed'] = args.deepspeed # deepspeed configs['use_amp'] = args.use_amp @@ -380,11 +357,11 @@ def scheduler(opt): # https://github.com/microsoft/DeepSpeed/issues/2993 with torch.no_grad(): model.save_checkpoint(save_dir=model_dir, tag='init') - if args.save_states == "model_only" and local_rank == 0: + if args.save_states == "model_only" and rank == 0: convert_zero_checkpoint_to_fp32_state_dict( model_dir, "{}/init.pt".format(model_dir), tag='init') os.system("rm -rf {}/{}".format(model_dir, "init")) - elif not args.deepspeed and start_epoch == 0 and local_rank == 0: + elif not args.deepspeed and start_epoch == 0 and rank == 0: save_model_path = os.path.join(model_dir, 'init.pt') save_checkpoint(model, save_model_path) @@ -413,7 +390,7 @@ def scheduler(opt): 'epoch': epoch, 'lr': lr, 'cv_loss': cv_loss, 'step': executor.step, 'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') } - if local_rank == 0: + if rank == 0: writer.add_scalar('epoch/cv_loss', cv_loss, epoch) writer.add_scalar('epoch/lr', lr, epoch) with open("{}/{}.yaml".format(model_dir, epoch), 'w') as fout: @@ -427,17 +404,17 @@ def scheduler(opt): model.save_checkpoint(save_dir=model_dir, tag='{}'.format(epoch), client_state=infos) - if args.save_states == "model_only" and local_rank == 0: + if args.save_states == "model_only" and rank == 0: convert_zero_checkpoint_to_fp32_state_dict( model_dir, "{}/{}.pt".format(model_dir, epoch), tag='{}'.format(epoch)) os.system("rm -rf {}/{}".format(model_dir, epoch)) - elif not args.deepspeed and local_rank == 0: + elif not args.deepspeed and rank == 0: save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) save_checkpoint(model, save_model_path, infos) final_epoch = epoch - if final_epoch is not None and local_rank == 0: + if final_epoch is not None and rank == 0: final_model_path = os.path.join(model_dir, 'final.pt') os.remove(final_model_path) if os.path.exists(final_model_path) else None 
os.symlink('{}.pt'.format(final_epoch), final_model_path) From 2f1a3fad4eadcb2cbfd02579cce463b29fc2cacf Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Wed, 20 Sep 2023 22:05:33 +0800 Subject: [PATCH 2/3] fix topo --- examples/aishell/s0/run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index eb7e68b4a..1ddb81e5f 100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -10,12 +10,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" stage=0 # start from 0 if you need to start from data preparation stop_stage=5 -# The aishell dataset location, please change this to your own path # You should change the following two parameters for multiple machine training, # see https://pytorch.org/docs/stable/elastic/run.html HOST_NODE_ADDR="localhost:0" num_nodes=1 +# The aishell dataset location, please change this to your own path # make sure of using absolute path. DO-NOT-USE relatvie path! data=/export/data/asr-data/OpenSLR/33/ data_url=www.openslr.org/resources/33 From 24a70a7072218f973782d0d369067dbe00ba80d6 Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Wed, 20 Sep 2023 22:28:01 +0800 Subject: [PATCH 3/3] rm standalone --- examples/aishell/s0/run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index 1ddb81e5f..e1b377ec2 100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -156,7 +156,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --pin_memory else echo "using torch ddp" - torchrun --standalone --nnodes=$num_nodes --nproc_per_node=$num_gpus --rdzv_endpoint=$HOST_NODE_ADDR \ + torchrun --nnodes=$num_nodes --nproc_per_node=$num_gpus --rdzv_endpoint=$HOST_NODE_ADDR \ wenet/bin/train.py \ --config $train_config \ --data_type $data_type \
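
Usage note (a sketch, not part of the patches): with this series applied, multi-machine training is configured through HOST_NODE_ADDR and num_nodes in run.sh; the old node_rank bookkeeping and the per-GPU launch loop are no longer needed because torchrun assigns RANK, LOCAL_RANK and WORLD_SIZE to every worker through the rendezvous. A minimal two-machine sketch, assuming a hypothetical hostname node1.example.com and torchrun's documented default rendezvous port 29400; depending on the torch version you may also need to pass --rdzv_backend=c10d and a shared --rdzv_id (see the elastic run docs linked from run.sh):

  # On every machine, set the same two variables in examples/aishell/s0/run.sh:
  #   num_nodes=2
  #   HOST_NODE_ADDR="node1.example.com:29400"  # hypothetical host, must be reachable from all nodes
  # then start the training stage on each machine (set stage/stop_stage to 4 in the
  # script, or pass them on the command line if tools/parse_options.sh is sourced):
  bash run.sh --stage 4 --stop_stage 4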
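
Optional sanity check (a hypothetical sketch, not part of the patches): before a full run you can confirm that torchrun populates the environment variables train.py now reads (WORLD_SIZE, RANK, LOCAL_RANK). check_env.py below is a throwaway script used only for this check:

  # Write a tiny script that prints the variables each worker sees, launch two
  # workers on one machine, then clean up.
  echo 'import os; print("RANK", os.environ["RANK"], "LOCAL_RANK", os.environ["LOCAL_RANK"], "WORLD_SIZE", os.environ["WORLD_SIZE"])' > check_env.py
  torchrun --standalone --nproc_per_node=2 check_env.py
  rm check_env.py

Each worker should print its own RANK/LOCAL_RANK and the shared WORLD_SIZE=2, which is exactly what dist.init_process_group now receives in train.py.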