Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiaobin-Rong authored May 14, 2024
1 parent 6ce40ce commit 0a57b40
Showing 1 changed file with 47 additions and 31 deletions.
78 changes: 47 additions & 31 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""
multiple GPUs version, using DDP training.
"""
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0, 1"
import toml
import torch
import random
import argparse
import numpy as np
import torch.distributed as dist

from trainer import Trainer
Expand All @@ -14,37 +12,50 @@
from loss_factory import loss_wavmag, loss_mse, loss_hybrid

# Pin every RNG source to a fixed seed so training runs are reproducible.
seed = 0
random.seed(seed)  # Python's built-in RNG
os.environ['PYTHONHASHSEED'] = str(seed)  # disable hash randomization (affects set/dict iteration order)
np.random.seed(seed)  # NumPy RNG
torch.manual_seed(seed)  # PyTorch CPU RNG
torch.cuda.manual_seed(seed)  # current CUDA device RNG
torch.cuda.manual_seed_all(seed)  # all CUDA devices — relevant for the multi-GPU (DDP) path

# NOTE(review): deliberately left disabled — fully deterministic cuDNN kernels cost speed.
# torch.backends.cudnn.deterministic =True

# Per-process training entry point.
#   rank   -- process index; GPU index when launched via mp.spawn, 0 for single-GPU.
#   config -- parsed TOML config dict (dataset/dataloader/network/optimizer/loss sections).
#   args   -- argparse namespace; args.world_size selects single-GPU vs DDP mode.
#
# NOTE(review): this span is a rendered git diff — OLD (unconditional DDP) and NEW
# (world_size-conditional) statements appear interleaved, so several steps occur twice
# below. The post-commit file keeps only the `if args.world_size > 1:` variants.
# Part of the body (loss-function branches / Trainer construction) is collapsed in the
# diff between the loss selection and trainer.train() — not visible here.
def run(rank, config, args):
# --- OLD version: always initialize the NCCL process group (removed by this commit) ---
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12354'
dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
torch.cuda.set_device(rank)
dist.barrier()


# Record this process's identity; args.device maps rank -> cuda:<rank>.
args.rank = rank
args.device = torch.device(rank)

# --- OLD version: datasets took **config['FFT'] and no explicit shuffle (removed) ---
train_dataset = MyDataset(**config['train_dataset'], **config['FFT'])
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, sampler=train_sampler,
                                               **config['train_dataloader'])

validation_dataset = MyDataset(**config['validation_dataset'], **config['FFT'])
validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset)
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, sampler=validation_sampler,
                                                    **config['validation_dataloader'])

# --- NEW version: DDP setup and samplers only when more than one process ---
if args.world_size > 1:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12354'
dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
torch.cuda.set_device(rank)
dist.barrier()

# shuffle=False is required here: DistributedSampler does the shuffling/sharding.
train_dataset = MyDataset(**config['train_dataset'])
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, sampler=train_sampler,
**config['train_dataloader'], shuffle=False)

validation_dataset = MyDataset(**config['validation_dataset'])
validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset)
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, sampler=validation_sampler,
**config['validation_dataloader'], shuffle=False)
else:
# Single-process path: no sampler; DataLoader shuffles the training set itself.
train_dataset = MyDataset(**config['train_dataset'])
train_sampler = None
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, **config['train_dataloader'], shuffle=True)

validation_dataset = MyDataset(**config['validation_dataset'])
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, **config['validation_dataloader'], shuffle=False)

# Build the network and move it onto this process's device before (optionally) wrapping in DDP.
model = DPCRN(**config['network_config'])
model.to(args.device)

# convert to DDP model
# NOTE(review): OLD line — unconditional DDP wrap, removed by this commit.
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
# NEW: wrap only when actually running distributed.
if args.world_size > 1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

optimizer = torch.optim.Adam(params=model.parameters(), lr=config['optimizer']['lr'])


# Select the loss by config key; remaining branches (mse / hybrid) are collapsed in this diff view.
if config['loss']['loss_func'] == 'wav_mag':
loss = loss_wavmag()
Expand All @@ -61,17 +72,22 @@ def run(rank, config, args):

trainer.train()

# NOTE(review): OLD line — unconditional teardown, removed by this commit.
dist.destroy_process_group()
# NEW: tear down the process group only if one was created.
if args.world_size > 1:
dist.destroy_process_group()


# Script entry point: parse CLI args, load the TOML config, then launch training
# either as multiple spawned DDP processes or as a plain single-process call.
# NOTE(review): diff-rendered span — OLD and NEW lines are interleaved below
# (two '-C' registrations, two world_size assignments, two launch paths).
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# OLD default (removed by this commit):
parser.add_argument('-C', '--config', default='config.toml')
# NEW default config filename:
parser.add_argument('-C', '--config', default='cfg_train.toml')
parser.add_argument('-D', '--device', default='0', help='The index of the available devices, e.g. 0,1,2,3')

args = parser.parse_args()

# Restrict visible GPUs before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
# OLD: world size inferred from the device list (removed by this commit).
args.world_size = len(args.device.split(','))
config = toml.load(args.config)
# NEW: world size is read from the config's [DDP] section.
args.world_size = config['DDP']['world_size']
# OLD: always spawn (removed); NEW: spawn only for multi-GPU, else call run() directly.
torch.multiprocessing.spawn(
run, args=(config, args,), nprocs=args.world_size, join=True)


if args.world_size > 1:
torch.multiprocessing.spawn(
run, args=(config, args,), nprocs=args.world_size, join=True)
else:
# Single GPU/CPU: run in-process as rank 0 (no process group needed).
run(0, config, args)

0 comments on commit 0a57b40

Please sign in to comment.