
Commit

Update trainer.py
Xiaobin-Rong authored May 14, 2024
1 parent 0a57b40 commit 13b22dd
Showing 1 changed file with 43 additions and 50 deletions.
trainer.py: 43 additions & 50 deletions (93 changes)
@@ -1,15 +1,13 @@
"""
multiple GPUs version, using DDP training.
"""
import os
import torch
import toml
from datetime import datetime
from tqdm import tqdm
from glob import glob
from pesq import pesq
from joblib import Parallel, delayed
import soundfile as sf
from torch.utils.tensorboard import SummaryWriter
from pesq import pesq
from distributed_utils import reduce_value


@@ -18,23 +16,20 @@ def __init__(self, config, model, optimizer, loss_func,
train_dataloader, validation_dataloader, train_sampler, args):
self.config = config
self.model = model

self.optimizer = optimizer

self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, 'min', factor=0.5, patience=5,verbose=True)

self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
self.optimizer, [80, 120, 150, 170, 180, 190, 200], gamma=0.5, verbose=False)
self.loss_func = loss_func

self.train_dataloader = train_dataloader
self.validation_dataloader = validation_dataloader

self.train_sampler = train_sampler
self.args = args
self.rank = args.rank
self.device = args.device
self.world_size = args.world_size

## training config
# training config
self.trainer_config = config['trainer']
self.epochs = self.trainer_config['epochs']
self.save_checkpoint_interval = self.trainer_config['save_checkpoint_interval']
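Note on the scheduler swap above: ReduceLROnPlateau, which halved the learning rate whenever the validation loss plateaued, is replaced by a fixed MultiStepLR schedule, so the step call no longer takes a metric (see the matching change in train() further down). A minimal, self-contained sketch of the new behaviour, using a stand-in model and the milestones from the diff:

import torch

model = torch.nn.Linear(4, 4)                         # stand-in model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[80, 120, 150, 170, 180, 190, 200], gamma=0.5)

for epoch in range(1, 201):
    optimizer.step()                                   # the real training step goes here
    scheduler.step()                                   # epoch-level step, no metric argument
    if epoch in (80, 120, 200):
        print(epoch, scheduler.get_last_lr())          # lr halves at each milestone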
@@ -43,7 +38,7 @@ def __init__(self, config, model, optimizer, loss_func,

if not self.resume:
self.exp_path = self.trainer_config['exp_path'] + '_' + datetime.now().strftime("%Y-%m-%d-%Hh%Mm")

else:
self.exp_path = self.trainer_config['exp_path'] + '_' + self.trainer_config['resume_datetime']

@@ -55,7 +50,7 @@ def __init__(self, config, model, optimizer, loss_func,
os.makedirs(self.checkpoint_path, exist_ok=True)
os.makedirs(self.sample_path, exist_ok=True)

## save the config
# save the config
if self.rank == 0:
with open(
os.path.join(
@@ -83,9 +78,11 @@ def _set_eval_mode(self):
self.model.eval()

def _save_checkpoint(self, epoch, score):
model_dict = self.model.module.state_dict() if self.world_size > 1 else self.model.state_dict()
state_dict = {'epoch': epoch,
'optimizer': self.optimizer.state_dict(),
'model': self.model.module.state_dict()}
'scheduler': self.scheduler.state_dict(),
'model': model_dict}

torch.save(state_dict, os.path.join(self.checkpoint_path, f'model_{str(epoch).zfill(4)}.tar'))
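The checkpoint above now stores the scheduler state as well and unwraps the DDP module only when more than one process is running. A condensed sketch of that pattern in function form (the function and its arguments are illustrative; the trainer reaches these objects through self.*):

import torch

def save_checkpoint(model, optimizer, scheduler, epoch, path, world_size):
    # Under DistributedDataParallel the raw network lives in model.module;
    # a plain single-process model exposes state_dict() directly.
    model_dict = model.module.state_dict() if world_size > 1 else model.state_dict()
    torch.save({'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'model': model_dict}, path)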

@@ -101,33 +98,38 @@ def _resume_checkpoint(self):

self.start_epoch = checkpoint['epoch'] + 1
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.model.module.load_state_dict(checkpoint['model'])
self.scheduler.load_state_dict(checkpoint['scheduler'])
if self.world_size > 1:
self.model.module.load_state_dict(checkpoint['model'])
else:
self.model.load_state_dict(checkpoint['model'])

def _train_epoch(self, epoch):
total_loss = 0
train_bar = tqdm(self.train_dataloader, ncols=120)
train_bar = tqdm(self.train_dataloader, ncols=110)

for step, (mixture, target) in enumerate(train_bar, 1):
mixture = mixture.to(self.device)
target = target.to(self.device)

esti_tagt = self.model(mixture)

loss = self.loss_func(esti_tagt, target)
loss = reduce_value(loss)
if self.world_size > 1:
loss = reduce_value(loss)
total_loss += loss.item()

train_bar.desc = ' train[{}/{}][{}]'.format(
epoch, self.epochs + self.start_epoch-1, datetime.now().strftime("%Y-%m-%d-%H:%M"))

train_bar.postfix = 'train_loss={:.3f}'.format(total_loss / step)
train_bar.postfix = 'train_loss={:.2f}'.format(total_loss / step)

self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm_value)
self.optimizer.step()

if self.device != torch.device("cpu"):
if self.world_size > 1 and (self.device != torch.device("cpu")):
torch.cuda.synchronize(self.device)

if self.rank == 0:
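In this hunk, reduce_value (imported from the repository's distributed_utils) is now called only when world_size > 1. That helper is not shown in this diff; the sketch below is an assumption about what such a function typically does, not the repository's actual code. With a single process there is nothing to average, and collective ops would fail without an initialised process group, hence the guard.

import torch
import torch.distributed as dist

def reduce_value(value, average=True):
    # Hypothetical helper: average a scalar tensor across all DDP ranks.
    world_size = dist.get_world_size()
    with torch.no_grad():
        dist.all_reduce(value)           # in-place sum over ranks
        if average:
            value /= world_size          # same averaged value on every rank
    return value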
@@ -140,43 +142,44 @@ def _validation_epoch(self, epoch):
total_loss = 0
total_pesq_score = 0

validation_bar = tqdm(self.validation_dataloader, ncols=132)
validation_bar = tqdm(self.validation_dataloader, ncols=123)
for step, (mixture, target) in enumerate(validation_bar, 1):
mixture = mixture.to(self.device)
target = target.to(self.device)

esti_tagt = self.model(mixture)

loss = self.loss_func(esti_tagt, target)
loss = reduce_value(loss)
if self.world_size > 1:
loss = reduce_value(loss)
total_loss += loss.item()

enhanced = torch.istft(esti_tagt[..., 0] + 1j*esti_tagt[..., 1], **self.config['FFT'], window=torch.hann_window(self.config['FFT']['win_length']).pow(0.5).to(self.device))
clean = torch.istft(target[..., 0] + 1j*target[..., 1], **self.config['FFT'], window=torch.hann_window(self.config['FFT']['win_length']).pow(0.5).to(self.device))
enhanced = torch.istft(esti_tagt[..., 0] + 1j*esti_tagt[..., 1], **self.config['FFT'], window=torch.hann_window(self.config['FFT']['win_length']).pow(0.5).to(self.device)).detach().cpu().numpy()
clean = torch.istft(target[..., 0] + 1j*target[..., 1], **self.config['FFT'], window=torch.hann_window(self.config['FFT']['win_length']).pow(0.5).to(self.device)).cpu().numpy()

enhanced = enhanced.squeeze().cpu().numpy()
clean = clean.squeeze().cpu().numpy()
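The enhanced and clean batches above are reconstructed with torch.istft using the same square-root Hann window for analysis and synthesis, then moved to NumPy in one step. A small runnable sketch of that reconstruction; the FFT settings are assumed values standing in for config['FFT']:

import torch

fft_cfg = {'n_fft': 512, 'hop_length': 256, 'win_length': 512}     # assumed FFT config
window = torch.hann_window(fft_cfg['win_length']).pow(0.5)          # sqrt-Hann synthesis window

spec = torch.randn(2, fft_cfg['n_fft'] // 2 + 1, 63, 2)             # (batch, freq, frames, real/imag)
wave = torch.istft(spec[..., 0] + 1j * spec[..., 1], **fft_cfg, window=window)
print(wave.shape)                                                    # (batch, samples)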

pesq_score = pesq(16000, clean, enhanced, 'wb')
pesq_score = reduce_value(torch.tensor(pesq_score, device=self.device))
pesq_score_batch = Parallel(n_jobs=-1)(
delayed(pesq)(16000, c, e, 'wb') for c, e in zip(clean, enhanced))
pesq_score = torch.tensor(pesq_score_batch, device=self.device).mean()
if self.world_size > 1:
pesq_score = reduce_value(pesq_score)
total_pesq_score += pesq_score
if self.args==0 and step <= 3:

if self.rank == 0 and step <= 3:
sf.write(os.path.join(self.sample_path,
'{}_enhanced_epoch{}_pesq={:.3f}.wav'.format(step, epoch, pesq_score)),
enhanced, 16000)
'{}_enhanced_epoch{}_pesq={:.3f}.wav'.format(step, epoch, pesq_score_batch[0])),
enhanced[0], 16000)
sf.write(os.path.join(self.sample_path,
'{}_clean.wav'.format(step)),
clean, 16000)
clean[0], 16000)

validation_bar.desc = 'validate[{}/{}][{}]'.format(
epoch, self.epochs + self.start_epoch-1, datetime.now().strftime("%Y-%m-%d-%H:%M"))

validation_bar.postfix = 'valid_loss={:.3f}, pesq={:.4f}'.format(
validation_bar.postfix = 'valid_loss={:.2f}, pesq={:.4f}'.format(
total_loss / step, total_pesq_score / step)


if self.device != torch.device("cpu"):
if (self.world_size > 1) and (self.device != torch.device("cpu")):
torch.cuda.synchronize(self.device)

if self.rank == 0:
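PESQ is now scored for every utterance in the batch via joblib instead of only the first one, and the batch mean is what gets reduced across ranks. A condensed sketch of the same pattern, where clean and enhanced are (batch, samples) NumPy arrays of 16 kHz speech (the batch_pesq wrapper is illustrative; the trainer inlines these two lines):

import numpy as np
from pesq import pesq
from joblib import Parallel, delayed

def batch_pesq(clean: np.ndarray, enhanced: np.ndarray, fs: int = 16000) -> float:
    # One wide-band PESQ score per utterance, computed in parallel on the CPU.
    scores = Parallel(n_jobs=-1)(
        delayed(pesq)(fs, c, e, 'wb') for c, e in zip(clean, enhanced))
    return float(np.mean(scores))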
@@ -188,26 +191,20 @@ def _validation_epoch(self, epoch):


def train(self):
if self.rank == 0:
timestamp_txt = os.path.join(self.exp_path, 'timestamp.txt')
mode = 'a' if os.path.exists(timestamp_txt) else 'w'
with open(timestamp_txt, mode) as f:
f.write('[{}] start for {} epochs\n'.format(
datetime.now().strftime("%Y-%m-%d-%H:%M"), self.epochs))

if self.resume:
self._resume_checkpoint()

for epoch in range(self.start_epoch, self.epochs + self.start_epoch):
self.train_sampler.set_epoch(epoch)
if self.train_sampler is not None:
self.train_sampler.set_epoch(epoch)

self._set_train_mode()
self._train_epoch(epoch)

self._set_eval_mode()
valid_loss, score = self._validation_epoch(epoch)

self.scheduler.step(valid_loss)
self.scheduler.step()

if (self.rank == 0) and (epoch % self.save_checkpoint_interval == 0):
self._save_checkpoint(epoch, score)
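Guarding set_epoch lets the same loop run with or without a DistributedSampler: under DDP the call reseeds the sampler's shuffle each epoch so every rank sees a new ordering, and on a single device it is simply skipped. A minimal runnable sketch of the guard with stand-in data (the sampler is None here; under DDP it would be a torch.utils.data.distributed.DistributedSampler):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(8, 2), torch.randn(8, 2))
train_sampler = None                           # would be a DistributedSampler under DDP
loader = DataLoader(dataset, batch_size=4,
                    shuffle=(train_sampler is None), sampler=train_sampler)

for epoch in range(1, 4):
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)         # new shuffle order per epoch across ranks
    for batch in loader:
        pass                                   # training / validation steps go here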
@@ -218,8 +215,4 @@ def train(self):
'best_model_{}.tar'.format(str(self.state_dict_best['epoch']).zfill(4))))

print('------------Training for {} epochs has done!------------'.format(self.epochs))

with open(timestamp_txt, 'a') as f:
f.write('[{}] end\n'.format(datetime.now().strftime("%Y-%m-%d-%H:%M")))


