From 0a57b40c11a77a79f993f3b35b6964fc715c5f65 Mon Sep 17 00:00:00 2001 From: Rong Xiaobin Date: Tue, 14 May 2024 11:14:35 +0800 Subject: [PATCH] Update train.py --- train.py | 78 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/train.py b/train.py index efc1a3d..0304bd4 100644 --- a/train.py +++ b/train.py @@ -1,11 +1,9 @@ -""" -multiple GPUs version, using DDP training. -""" import os -os.environ["CUDA_VISIBLE_DEVICES"]="0, 1" import toml import torch +import random import argparse +import numpy as np import torch.distributed as dist from trainer import Trainer @@ -14,37 +12,50 @@ from loss_factory import loss_wavmag, loss_mse, loss_hybrid seed = 0 +random.seed(seed) +os.environ['PYTHONHASHSEED'] = str(seed) +np.random.seed(seed) torch.manual_seed(seed) +torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - +# torch.backends.cudnn.deterministic =True def run(rank, config, args): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12354' - dist.init_process_group("nccl", rank=rank, world_size=args.world_size) - torch.cuda.set_device(rank) - dist.barrier() - + args.rank = rank args.device = torch.device(rank) - - train_dataset = MyDataset(**config['train_dataset'], **config['FFT']) - train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) - train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, sampler=train_sampler, - **config['train_dataloader']) - validation_dataset = MyDataset(**config['validation_dataset'], **config['FFT']) - validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset) - validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, sampler=validation_sampler, - **config['validation_dataloader']) - + if args.world_size > 1: + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12354' + dist.init_process_group("nccl", rank=rank, world_size=args.world_size) + torch.cuda.set_device(rank) + dist.barrier() + + train_dataset = MyDataset(**config['train_dataset']) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, sampler=train_sampler, + **config['train_dataloader'], shuffle=False) + + validation_dataset = MyDataset(**config['validation_dataset']) + validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset) + validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, sampler=validation_sampler, + **config['validation_dataloader'], shuffle=False) + else: + train_dataset = MyDataset(**config['train_dataset']) + train_sampler = None + train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, **config['train_dataloader'], shuffle=True) + + validation_dataset = MyDataset(**config['validation_dataset']) + validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, **config['validation_dataloader'], shuffle=False) + model = DPCRN(**config['network_config']) model.to(args.device) - # convert to DDP model - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + if args.world_size > 1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + optimizer = torch.optim.Adam(params=model.parameters(), lr=config['optimizer']['lr']) - if config['loss']['loss_func'] == 'wav_mag': loss = loss_wavmag() @@ -61,17 +72,22 @@ def run(rank, config, args): trainer.train() - dist.destroy_process_group() + if args.world_size > 1: + dist.destroy_process_group() if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-C', '--config', default='config.toml') + parser.add_argument('-C', '--config', default='cfg_train.toml') + parser.add_argument('-D', '--device', default='0', help='The index of the available devices, e.g. 0,1,2,3') args = parser.parse_args() - + os.environ["CUDA_VISIBLE_DEVICES"] = args.device + args.world_size = len(args.device.split(',')) config = toml.load(args.config) - args.world_size = config['DDP']['world_size'] - torch.multiprocessing.spawn( - run, args=(config, args,), nprocs=args.world_size, join=True) - + + if args.world_size > 1: + torch.multiprocessing.spawn( + run, args=(config, args,), nprocs=args.world_size, join=True) + else: + run(0, config, args)