From f2701f6b2b19d9f7741d6fce698035af7fb3d974 Mon Sep 17 00:00:00 2001 From: Rong Xiaobin Date: Sun, 26 Nov 2023 20:36:37 +0800 Subject: [PATCH] Add files via upload --- README.md | 41 ++++++++ config.toml | 51 ++++++++++ config.yaml | 11 +++ datasets.py | 75 +++++++++++++++ distributed_utils.py | 70 ++++++++++++++ infer_folder.py | 52 ++++++++++ infer_loader.py | 98 +++++++++++++++++++ loss_factory.py | 141 +++++++++++++++++++++++++++ model.py | 131 +++++++++++++++++++++++++ requirements.txt | 13 +++ score_utils.py | 20 ++++ train.py | 77 +++++++++++++++ train_sg.py | 52 ++++++++++ trainer.py | 224 +++++++++++++++++++++++++++++++++++++++++++ trainer_sg.py | 202 ++++++++++++++++++++++++++++++++++++++ 15 files changed, 1258 insertions(+) create mode 100644 README.md create mode 100644 config.toml create mode 100644 config.yaml create mode 100644 datasets.py create mode 100644 distributed_utils.py create mode 100644 infer_folder.py create mode 100644 infer_loader.py create mode 100644 loss_factory.py create mode 100644 model.py create mode 100644 requirements.txt create mode 100644 score_utils.py create mode 100644 train.py create mode 100644 train_sg.py create mode 100644 trainer.py create mode 100644 trainer_sg.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..0733a3a --- /dev/null +++ b/README.md @@ -0,0 +1,41 @@ +# A code template for training DNN-based speech enhancement models. +A training code template is highly valuable for deep learning engineers as it can significantly enhance their work efficiency. Despite different individuals have varying coding styles, some are excellent while others may not be as good. My philosophy is to prioritize simplicity. In this context, I am sharing a practical organizational structure for training code files in speech enhancement (SE). The primary focus is on keeping it concise and intuitive rather than aiming for comprehensiveness. + +## File Specification +For training: +* `config.toml`: Specifies the training configurations. +* `datasets.py`: Provides the dataset class for the dataloader. +* `distributed_utils.py`: Assists with Distributed Data Parallel (DDP) training. +* `model.py`: Defines the model. +* `loss_factory.py`: Provides various useful loss functions in SE. +* `train_sg.py`: Conducts the training process for a single GPU machine. +* `train.py`: Conducts the training process for multiple GPUs. +* `trainer_sg.py`: Encapsulates various functions during training for a single GPU machine. +* `trainer.py`: Encapsulates various functions during training for multiple GPUs. + +For evaluation: +* `config.yaml`: Specifies evaluation paths. +* `infer_folder.py`: Conducts evaluation on a folder of WAV files. +* `infer_loader.py`: Conducts evaluation using a dataloader. +* `score_utils.py`: Provides calculations for various metrics. + +## Usage +When starting a new SE project, you should follow these steps: +1. Modify `datasets.py`; +2. Define your own `model.py`; +3. Modify the `config.toml` to match your training setup; +4. Select a loss function in `loss_factory.py`, or create a new one if needed; +5. Probably do not need to modify `trainer.py` or `trainer_sg.py`; +6. Run the `train.py` or `train_sg.py` based on the number of available GPUs. +7. Before evaluation, remember to modify `config.yaml` to ensure that the paths are correctly configured. + +## Note +The code is originally intended for Linux systems, and if you attempt to adapt it to the Windows platform, you may encounter certain issues: +* Incompatibility of paths: The file paths used in Linux systems may not be compatible with the file paths in Windows. + +* Challenges in installing the pesq package: The process of installing the pesq package on Windows may not be straightforward and may require additional steps or configurations. + +Please keep these considerations in mind when working with the code on the Windows platform. + +## Acknowledgement +This code template heavily borrows from the excellent [Sheffield_Clarity_CEC1_Entry](https://github.com/TuZehai/Sheffield_Clarity_CEC1_Entry) reposity in many aspects. \ No newline at end of file diff --git a/config.toml b/config.toml new file mode 100644 index 0000000..ae4a29e --- /dev/null +++ b/config.toml @@ -0,0 +1,51 @@ +[network_config] + + +[DDP] +world_size = 2 # number of available gpus + +[optimizer] +lr = 1e-3 + +[loss] +loss_func = 'hybrid' + +[listener] +listener_sr = 16000 + +[FFT] +n_fft = 512 +hop_length = 256 +win_length = 512 + +[train_dataset] +train_folder = '/data/ssd0/xiaobin.rong/Datasets/DNS3/train' +shuffle = false +num_tot = 0 +wav_len = 0 + +[train_dataloader] +batch_size = 16 +num_workers = 4 +drop_last = true +pin_memory = true + +[validation_dataset] +train_folder = '/data/ssd0/xiaobin.rong/Datasets/DNS3/dev' +shuffle = false +num_tot = 0 +wav_len = 0 + +[validation_dataloader] +batch_size = 1 +num_workers = 4 +pin_memory = true + +[trainer] +epochs = 120 +save_checkpoint_interval = 1 +clip_grad_norm_value = 3.0 +exp_path = '/data/ssd0/xiaobin.rong/project_se/DNS3/exp_dpcrn' +resume = false +resume_datetime = '' +resume_step = 0 diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..45b1d91 --- /dev/null +++ b/config.yaml @@ -0,0 +1,11 @@ +path: + exp_folder: ${network.root}/${network.exp_name}/${network.cpt_name}/enhanced # where enhanced speech store + +network: + root: /data/ssd0/xiaobin.rong/project_se/DNS3 + exp_name: exp_dpcrn_2023-08-02-11h32m + cpt_name: model_0082 + checkpoint: ${network.exp_path}/checkpoints/${network.cpt_name}.tar + exp_path: ${network.root}/${network.exp_name} + cfg_toml: ${network.exp_path}/config.toml + diff --git a/datasets.py b/datasets.py new file mode 100644 index 0000000..f1f15be --- /dev/null +++ b/datasets.py @@ -0,0 +1,75 @@ +import os +import toml +import random +import torch +import pandas as pd +import soundfile as sf +from torch.utils import data + + +class MyDataset(data.Dataset): + def __init__(self, train_folder, shuffle, num_tot, wav_len=0, n_fft=512, hop_length=256, win_length=512): + super().__init__() + ### We store the noisy-clean pairs in the same folder, and use CSV file to manage all the WAV files. + self.file_name = pd.read_csv(os.path.join(train_folder, 'INFO.csv'))['file_name'].to_list() + + if shuffle: + random.seed(7) + random.shuffle(self.file_name) + + if num_tot != 0: + self.file_name = self.file_name[: num_tot] + + self.train_folder = train_folder + self.wav_len = wav_len + + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + + def __getitem__(self, idx): + noisy, fs = sf.read(os.path.join(self.train_folder, self.file_name[idx] + '_noisy.wav'), dtype="float32") + clean, fs = sf.read(os.path.join(self.train_folder, self.file_name[idx] + '_clean.wav'), dtype="float32") + + noisy = torch.tensor(noisy) + clean = torch.tensor(clean) + + if self.wav_len != 0: + start = random.choice(range(len(clean) - self.wav_len * fs)) + noisy = noisy[start: start + self.wav_len*fs] + clean = clean[start: start + self.wav_len*fs] + + noisy = torch.stft(noisy, self.n_fft, self.hop_length, self.win_length, torch.hann_window(self.win_length).pow(0.5), return_complex=False) + clean = torch.stft(clean, self.n_fft, self.hop_length, self.win_length, torch.hann_window(self.win_length).pow(0.5), return_complex=False) + + return noisy, clean + + def __len__(self): + return len(self.file_name) + + +if __name__=='__main__': + from tqdm import tqdm + config = toml.load('config.toml') + + device = torch.device('cuda') + + train_dataset = MyDataset(**config['train_dataset'], **config['FFT']) + train_dataloader = data.DataLoader(train_dataset, **config['train_dataloader']) + + validation_dataset = MyDataset(**config['validation_dataset'], **config['FFT']) + validation_dataloader = data.DataLoader(validation_dataset, **config['validation_dataloader']) + + print(len(train_dataloader), len(validation_dataloader)) + + for noisy, clean in tqdm(train_dataloader): + print(noisy.shape, clean.shape) + break + # pass + + for noisy, clean in tqdm(validation_dataloader): + print(noisy.shape, clean.shape) + break + + + diff --git a/distributed_utils.py b/distributed_utils.py new file mode 100644 index 0000000..5b29de1 --- /dev/null +++ b/distributed_utils.py @@ -0,0 +1,70 @@ +import os + +import torch +import torch.distributed as dist + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' # Communication backend, NVIDIA GPUs are recommended to use NCCL. + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + dist.barrier() + + +def cleanup(): + dist.destroy_process_group() + + +def is_dist_avail_and_initialized(): + """ Check if distributed environment is supported. """ + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def reduce_value(value, average=True): + world_size = get_world_size() + if world_size < 2: # for single GPU + return value + + with torch.no_grad(): + dist.all_reduce(value) + if average: + value /= world_size + + return value diff --git a/infer_folder.py b/infer_folder.py new file mode 100644 index 0000000..d12652a --- /dev/null +++ b/infer_folder.py @@ -0,0 +1,52 @@ +""" +conduct evaluation on a folder of WAV files, with computing SISNR, PESQ, ESTOI, and DNSMOS. +""" +import os +os.environ["CUDA_VISIBLE_DEVICES"]="0" +import toml +import torch +from tqdm import tqdm +import soundfile as sf +from omegaconf import OmegaConf +from model import DPCRN + + +cfg_yaml = OmegaConf.load('config.yaml') +test_folder = '/data/ssd0/xiaobin.rong/Datasets/DNS3/blind_test_set/dns-challenge-3-final-evaluation/wideband_16kHz/noisy_clips_wb_16kHz/' +test_wavnames = list(filter(lambda x: x.endswith("wav"), os.listdir(test_folder))) + +cfg_toml = toml.load(cfg_yaml.network.cfg_toml) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +netout_folder = f'{cfg_yaml.path.exp_folder}' +os.makedirs(netout_folder, exist_ok=True) + +### load model +model = DPCRN(**cfg_toml['network_config']) +model.to(device) +checkpoint = torch.load(cfg_yaml.network.checkpoint, map_location=device) +model.load_state_dict(checkpoint['model']) +model.eval() + +for param in model.parameters(): + param.requires_grad = False + +### compute SISNR, PESQ and ESTOI +with torch.no_grad(): + for name in tqdm(test_wavnames): + noisy, fs = sf.read(os.path.join(test_folder, name), dtype="float32") + noisy = torch.stft(torch.from_numpy(noisy), **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5)) + noisy = noisy.to(device) + + estimate= model(noisy[None, ...]) # (B,F,T,2) + + enhanced = torch.istft(estimate[..., 0] + 1j*estimate[..., 1], **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5).to(device)) + out = enhanced.cpu().detach().numpy().squeeze() + + sf.write(os.path.join(netout_folder, name[:-4]+'_enh.wav'), out, fs) + +### compute DNSMOS +os.chdir('DNSMOS') +out_dir = os.path.join(netout_folder, 'dnsmos_enhanced_p808.csv') +os.system(f'python dnsmos_local_p808.py -t {netout_folder} -o {out_dir}') + diff --git a/infer_loader.py b/infer_loader.py new file mode 100644 index 0000000..157612c --- /dev/null +++ b/infer_loader.py @@ -0,0 +1,98 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"]="1" + +import toml +import torch +import pandas as pd +import soundfile as sf +from tqdm import tqdm +from pesq import pesq +from pystoi import stoi +from score_utils import sisnr +from omegaconf import OmegaConf +from datasets import MyDataset +from model import DPCRN + + +@torch.no_grad() +def infer(cfg_yaml): + + save_wavs = input('>>> Save wavs? (y/n) ') + if save_wavs == 'y': + mark = input('>>> Please enter a tag for the saved wav names: ') + + cfg_toml = toml.load(cfg_yaml.network.cfg_toml) + cfg_toml['validation_dataset']['train_folder'] = '/data/ssd0/xiaobin.rong/Datasets/DNS3/test/' + cfg_toml['validation_dataset']['num_tot'] = 0 # all utterances + cfg_toml['validation_dataset']['wav_len'] = 0 # full wav length + cfg_toml['validation_dataloader']['batch_size'] = 1 # one utterence once + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + netout_folder = f'{cfg_yaml.path.exp_folder}' + os.system(f'rm {netout_folder}/*.wav') + os.makedirs(netout_folder, exist_ok=True) + + validation_dataset = MyDataset(**cfg_toml['validation_dataset']) + validation_filename = validation_dataset.file_name + + validation_dataloader = torch.utils.data.DataLoader(validation_dataset, **cfg_toml['validation_dataloader']) + + ### load model + model = DPCRN(**cfg_toml['network_config']) + model.to(device) + checkpoint = torch.load(cfg_yaml.network.checkpoint, map_location=device) + model.load_state_dict(checkpoint['model']) + model.eval() + + for param in model.parameters(): + param.requires_grad = False + + ### compute SISNR, PESQ, and ESTOI + INFO1 = [] + INFO = pd.read_csv(os.path.join(cfg_toml['validation_dataset']['train_folder'], 'INFO.csv')) + for step, (mixture, target) in enumerate(tqdm(validation_dataloader)): + + mixture = mixture.to(device) + target = target.to(device) + + estimate= model(mixture) # [B, F, T, 2] + + enhanced = torch.istft(estimate[..., 0] + 1j*estimate[..., 1], **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5).to(device)) + clean = torch.istft(target[..., 0] + 1j*target[..., 1], **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5).to(device)) + + out = enhanced.cpu().detach().numpy().squeeze() + clean = clean.cpu().detach().numpy().squeeze() + + # out = torch.clamp(out, -1, 1) + # out = out / out.max() * 0.5 + + sisnr_score = sisnr(out, clean) + pesq_score = pesq(16000, clean, out, 'wb') + estoi_score = stoi(clean, out, 16000, extended=True) + + ## save wavs + if save_wavs == 'y': + save_name = "{}_{}_{:.2f}_{:.2f}_{:.2f}.wav".format(validation_filename[step], mark, sisnr_score, pesq_score, estoi_score) + + sf.write( + os.path.join(netout_folder, save_name), out, cfg_toml['listener']['listener_sr']) + + ## save infos + file_name = validation_filename[step] + INFO1.append([file_name, sisnr_score, pesq_score, estoi_score]) + + INFO1 = pd.DataFrame(INFO1, columns=['file_name', 'sisnr', 'pesq', 'estoi']) + INFO2 = pd.merge(INFO, INFO1) + INFO2.to_csv(os.path.join(netout_folder, 'INFO2.csv'), index=None) + + ### compute DNSMOS + os.chdir('DNSMOS') + out_dir = os.path.join(netout_folder, 'dnsmos_enhanced_p808.csv') + os.system(f'python dnsmos_local_p808.py -t {netout_folder} -o {out_dir}') + + +if __name__ == "__main__": + cfg_yaml = OmegaConf.load('config.yaml') + infer(cfg_yaml) + diff --git a/loss_factory.py b/loss_factory.py new file mode 100644 index 0000000..c91c259 --- /dev/null +++ b/loss_factory.py @@ -0,0 +1,141 @@ +import toml +import torch +import torch.nn as nn +from torch_stoi import NegSTOILoss + + +config = toml.load('config.toml') + + +class loss_mse(nn.Module): + def __init__(self): + super(loss_mse, self).__init__() + self.window = torch.hann_window(config['FFT']['win_length']).pow(0.5) + self.mse_loss = nn.MSELoss(reduction='mean') + + def forward(self, est, clean): + """ inputs: spectrograms, (B,F,T,2) """ + data_len = min(est.shape[-1], clean.shape[-1]) + est = est[..., :data_len] + clean = clean[..., :data_len] + + est_stft = torch.stft(est, **config['FFT'], center=True, window=self.window.to(est.device), return_complex=False) + clean_stft = torch.stft(clean, **config['FFT'], center=True, window=self.window.to(clean.device), return_complex=False) + est_stft_real, est_stft_imag = est_stft[:,:,:,0], est_stft[:,:,:,1] + clean_stft_real, clean_stft_imag = clean_stft[:,:,:,0], clean_stft[:,:,:,1] + est_mag = torch.sqrt(est_stft_real**2 + est_stft_imag**2 + 1e-12) + clean_mag = torch.sqrt(clean_stft_real**2 + clean_stft_imag**2 + 1e-12) + est_real_c = est_stft_real / (est_mag**(0.7)) + est_imag_c = est_stft_imag / (est_mag**(0.7)) + clean_real_c = clean_stft_real / (clean_mag**(0.7)) + clean_imag_c = clean_stft_imag / (clean_mag**(0.7)) + + loss = 0.7 * self.mse_loss(est_mag**(0.3), clean_mag**(0.3)) + \ + 0.3 * (self.mse_loss(est_real_c, clean_real_c) + \ + self.mse_loss(est_imag_c, clean_imag_c)) + + return loss + + +class loss_sisnr(nn.Module): + def __init__(self): + super(loss_sisnr, self).__init__() + + def forward(self, est, clean): + """ inputs: waveform, (B,...,T) """ + data_len = min(est.shape[-1], clean.shape[-1]) + est = est[..., :data_len] + clean = clean[...,:data_len] + est = est - torch.mean(est, dim=-1, keepdim=True) + clean = clean - torch.mean(clean, dim=-1, keepdim=True) + + target = torch.sum(est * clean, 1, keepdim=True) * clean / \ + torch.sum(clean**2 + 1e-8, 1, keepdim=True) + noise = est - target + sisnr = 10*torch.log10((torch.sum(target**2, 1) + 1e-8)/(torch.sum(noise**2, 1) + 1e-8)) + est_std = torch.std(est, dim=1) + clean_std = torch.std(clean, dim=1) + + com_factor = torch.minimum((est_std + 1e-8) / (clean_std + 1e-8), + (clean_std + 1e-8) / (est_std + 1e-8)) + + return -torch.mean(sisnr * com_factor) + + +class loss_stoi(torch.nn.Module): + def __init__(self, sample_rate): + super(loss_stoi, self).__init__() + self.NegSTOI = NegSTOILoss(sample_rate=sample_rate) + + def forward(self, est, clean): + """ inputs: waveform, (B,...,T) """ + data_len = min(est.shape[-1], clean.shape[-1]) + est = est[..., : data_len] + clean = clean[...,: data_len] + + return self.NegSTOI(est, clean).mean() + + +class loss_wavmag(nn.Module): + def __init__(self): + super(loss_wavmag, self).__init__() + + def forward(self, est_stft, clean_stft, alpha=10): + """ inputs: spectrograms, (B,F,T,2) """ + device = est_stft.device + + est_stft = est_stft[..., 0] + 1j*est_stft[..., 1] + clean_stft = clean_stft[..., 0] + 1j*clean_stft[..., 1] + + estimated = torch.istft(est_stft, **config['FFT'], window=torch.hann_window(512).pow(0.5).to(device)) + clean = torch.istft(clean_stft, **config['FFT'], window=torch.hann_window(512).pow(0.5).to(device)) + + loss_wav = torch.norm((estimated - clean), p=1) / clean.numel() * 100 + loss_mag = torch.norm(abs(est_stft) - abs(clean_stft), p=1) / clean_stft.numel() * 100 + return alpha*loss_wav + loss_mag + + +class loss_hybrid(nn.Module): + def __init__(self): + super().__init__() + self.window = torch.hann_window(config['FFT']['win_length']).pow(0.5) + + def forward(self, pred_stft, true_stft): + """ inputs: spectrograms, (B,F,T,2) """ + device = pred_stft.device + + pred_stft_real, pred_stft_imag = pred_stft[:,:,:,0], pred_stft[:,:,:,1] + true_stft_real, true_stft_imag = true_stft[:,:,:,0], true_stft[:,:,:,1] + pred_mag = torch.sqrt(pred_stft_real**2 + pred_stft_imag**2 + 1e-12) + true_mag = torch.sqrt(true_stft_real**2 + true_stft_imag**2 + 1e-12) + pred_real_c = pred_stft_real / (pred_mag**(0.7)) + pred_imag_c = pred_stft_imag / (pred_mag**(0.7)) + true_real_c = true_stft_real / (true_mag**(0.7)) + true_imag_c = true_stft_imag / (true_mag**(0.7)) + real_loss = torch.mean((pred_real_c - true_real_c)**2) + imag_loss = torch.mean((pred_imag_c - true_imag_c)**2) + mag_loss = torch.mean((pred_mag**(0.3)-true_mag**(0.3))**2) + + + y_pred = torch.istft(pred_stft_real+1j*pred_stft_imag, **config['FFT'], window=self.window.to(device)) + y_true = torch.istft(true_stft_real+1j*true_stft_imag, **config['FFT'], window=self.window.to(device)) + y_true = torch.sum(y_true * y_pred, dim=-1, keepdim=True) * y_true / (torch.sum(torch.square(y_true),dim=-1,keepdim=True) + 1e-8) + sisnr = - torch.log10(torch.sum(torch.square(y_true),dim=-1,keepdim=True) / torch.sum(torch.square(y_pred - y_true),dim=-1,keepdim=True) + 1e-8).mean() + + return 30*(real_loss + imag_loss) + 70*mag_loss + sisnr + + + + +if __name__=='__main__': + a = torch.randn(2,10000) + b = torch.randn(2, 9990) + loss_func = loss_sisnr() + loss = loss_func(a,b) + print(loss) + + S_ = torch.randn(3, 257, 91, 2) + S = torch.randn(3, 257, 91, 2) + loss_func = loss_hybrid() + loss = loss_func(S_, S) + print(loss) \ No newline at end of file diff --git a/model.py b/model.py new file mode 100644 index 0000000..a806750 --- /dev/null +++ b/model.py @@ -0,0 +1,131 @@ +""" +1.74 GMac 787.15 k +""" +import torch +import torch.nn as nn + + +class DPRNN(nn.Module): + def __init__(self, numUnits, width, channel, **kwargs): + super(DPRNN, self).__init__(**kwargs) + self.numUnits = numUnits + self.width = width + self.channel = channel + + self.intra_rnn = nn.LSTM(input_size = self.numUnits, hidden_size = self.numUnits//2, batch_first = True, bidirectional = True) + self.intra_fc = nn.Linear(self.numUnits, self.numUnits) + self.intra_ln = nn.LayerNorm((width, numUnits), eps=1e-8) + + self.inter_rnn = nn.LSTM(input_size = self.numUnits, hidden_size = self.numUnits, batch_first = True, bidirectional = False) + self.inter_fc = nn.Linear(self.numUnits, self.numUnits) + self.inter_ln = nn.LayerNorm((width, numUnits), eps=1e-8) + + def forward(self,x): + # x: (B, C, T, F) + ## Intra RNN + x = x.permute(0, 2, 3, 1) # (B,T,F,C) + intra_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3]) # (B*T,F,C) + intra_x = self.intra_rnn(intra_x)[0] # (B*T,F,C) + intra_x = self.intra_fc(intra_x) # (B*T,F,C) + intra_x = intra_x.reshape(x.shape[0], -1, self.width, self.channel) # (B,T,F,C) + intra_x = self.intra_ln(intra_x) + intra_out = torch.add(x, intra_x) + + ## Inter RNN + x = intra_out.permute(0,2,1,3) # (B,F,T,C) + inter_x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3]) + inter_x = self.inter_rnn(inter_x)[0] # (B*F,T,C) + inter_x = self.inter_fc(inter_x) # (B*F,T,C) + inter_x = inter_x.reshape(x.shape[0], self.width, -1, self.channel) # (B,F,T,C) + inter_x = inter_x.permute(0,2,1,3) # (B,T,F,C) + inter_x = self.inter_ln(inter_x) + inter_out = torch.add(intra_out, inter_x) + + dual_out = inter_out.permute(0,3,1,2) # (B,C,T,F) + + return dual_out + + +class DPCRN(nn.Module): + def __init__(self): + super().__init__() + self.en_conv1 = nn.Sequential(nn.ConstantPad2d([2,2,1,0], 0), + nn.Conv2d(2, 32, (2,5), (1,2)), + nn.BatchNorm2d(32), + nn.PReLU()) + self.en_conv2 = nn.Sequential(nn.ConstantPad2d([1,1,1,0], 0), + nn.Conv2d(32, 32, (2,3), (1,2)), + nn.BatchNorm2d(32), + nn.PReLU()) + self.en_conv3 = nn.Sequential(nn.ConstantPad2d([1,1,1,0], 0), + nn.Conv2d(32, 32, (2,3), (1,2)), + nn.BatchNorm2d(32), + nn.PReLU()) + self.en_conv4 = nn.Sequential(nn.ConstantPad2d([1,1,1,0], 0), + nn.Conv2d(32, 64, (2,3), (1,1)), + nn.BatchNorm2d(64), + nn.PReLU()) + self.en_conv5 = nn.Sequential(nn.ConstantPad2d([1,1,1,0], 0), + nn.Conv2d(64, 128, (2,3), (1,1)), + nn.BatchNorm2d(128), + nn.PReLU()) + self.dprnn1 = DPRNN(128, 33, 128) + self.dprnn2 = DPRNN(128, 33, 128) + + self.de_conv5 = nn.Sequential(nn.ConvTranspose2d(256, 64, (2,3), (1,1)), + nn.BatchNorm2d(64), + nn.PReLU()) + self.de_conv4 = nn.Sequential(nn.ConvTranspose2d(128, 32, (2,3), (1,1)), + nn.BatchNorm2d(32), + nn.PReLU()) + self.de_conv3 = nn.Sequential(nn.ConvTranspose2d(64, 32, (2,3), (1,2)), + nn.BatchNorm2d(32), + nn.PReLU()) + self.de_conv2 = nn.Sequential(nn.ConvTranspose2d(64, 32, (2,3), (1,2)), + nn.BatchNorm2d(32), + nn.PReLU()) + self.de_conv1 = nn.Sequential(nn.ConvTranspose2d(64, 2, (2,5), (1,2)), + nn.BatchNorm2d(2)) + + def forward(self, x): + """ + x: (B,F,T,2) + """ + x_ref = x + x = x.permute(0, 3, 2, 1) # (B,C,T,F) + en_x1 = self.en_conv1(x) # ; print(en_x1.shape) + en_x2 = self.en_conv2(en_x1) # ; print(en_x2.shape) + en_x3 = self.en_conv3(en_x2) # ; print(en_x3.shape) + en_x4 = self.en_conv4(en_x3) # ; print(en_x4.shape) + en_x5 = self.en_conv5(en_x4) # ; print(en_x5.shape) + + en_xr = self.dprnn1(en_x5) # ; print(en_xr.shape) + en_xr = self.dprnn2(en_xr) # ; print(en_xr.shape) + + de_x5 = self.de_conv5(torch.cat([en_x5, en_xr], dim=1))[...,:-1,:-2] #; print(de_x5.shape) + de_x4 = self.de_conv4(torch.cat([en_x4, de_x5], dim=1))[...,:-1,:-2] #; print(de_x4.shape) + de_x3 = self.de_conv3(torch.cat([en_x3, de_x4], dim=1))[...,:-1,:-2] #; print(de_x3.shape) + de_x2 = self.de_conv2(torch.cat([en_x2, de_x3], dim=1))[...,:-1,:-2] #; print(de_x2.shape) + de_x1 = self.de_conv1(torch.cat([en_x1, de_x2], dim=1))[...,:-1,:-4] #; print(de_x1.shape) + + m = de_x1.permute(0,3,2,1) + + s_real = x_ref[...,0] * m[...,0] - x_ref[...,1] * m[...,1] + s_imag = x_ref[...,1] * m[...,0] + x_ref[...,0] * m[...,1] + s = torch.stack([s_real, s_imag], dim=-1) # (B,F,T,2) + + return s + + +if __name__ == "__main__": + model = DPCRN().cuda() + + from ptflops import get_model_complexity_info + flops, params = get_model_complexity_info(model, (257, 63, 2), as_strings=True, + print_per_layer_stat=False, verbose=True) + print(flops, params) + + model = model.cpu().eval() + x = torch.randn(1, 257, 63, 2) + y = model(x) + print(y.shape) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8f7fbba --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +einops==0.7.0 +librosa==0.10.1 +numpy==1.24.4 +omegaconf==2.3.0 +pandas==2.1.3 +pesq==0.0.4 +ptflops==0.7 +pystoi==0.3.3 +soundfile==0.12.1 +toml==0.10.2 +torch==1.11.0 +torch_stoi==0.1.2 +tqdm==4.66.1 diff --git a/score_utils.py b/score_utils.py new file mode 100644 index 0000000..d4cbc65 --- /dev/null +++ b/score_utils.py @@ -0,0 +1,20 @@ +import numpy as np + +def sisnr(esti, tagt): + """ for single wav """ + esti = esti - esti.mean() + tagt = tagt - tagt.mean() + + a = np.sum(esti * tagt) / np.sum(tagt**2 + 1e-8) + e_tagt = a * tagt + e_res = esti - e_tagt + + return 10*np.log10((np.sum(e_tagt**2)+1e-8) / (np.sum(e_res**2)+1e-8)) + + + + +if __name__ == "__main__": + x = np.random.randn(100) + s = np.random.randn(100) + print(sisnr(x, s)) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..efc1a3d --- /dev/null +++ b/train.py @@ -0,0 +1,77 @@ +""" +multiple GPUs version, using DDP training. +""" +import os +os.environ["CUDA_VISIBLE_DEVICES"]="0, 1" +import toml +import torch +import argparse +import torch.distributed as dist + +from trainer import Trainer +from model import DPCRN +from datasets import MyDataset +from loss_factory import loss_wavmag, loss_mse, loss_hybrid + +seed = 0 +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) + + +def run(rank, config, args): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12354' + dist.init_process_group("nccl", rank=rank, world_size=args.world_size) + torch.cuda.set_device(rank) + dist.barrier() + + args.rank = rank + args.device = torch.device(rank) + + train_dataset = MyDataset(**config['train_dataset'], **config['FFT']) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, sampler=train_sampler, + **config['train_dataloader']) + + validation_dataset = MyDataset(**config['validation_dataset'], **config['FFT']) + validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset) + validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, sampler=validation_sampler, + **config['validation_dataloader']) + + model = DPCRN(**config['network_config']) + model.to(args.device) + + # convert to DDP model + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + optimizer = torch.optim.Adam(params=model.parameters(), lr=config['optimizer']['lr']) + + + if config['loss']['loss_func'] == 'wav_mag': + loss = loss_wavmag() + elif config['loss']['loss_func'] == 'mse': + loss = loss_mse() + elif config['loss']['loss_func'] == 'hybrid': + loss = loss_hybrid() + else: + raise(NotImplementedError) + + trainer = Trainer(config=config, model=model,optimizer=optimizer, loss_func=loss, + train_dataloader=train_dataloader, validation_dataloader=validation_dataloader, + train_sampler=train_sampler, args=args) + + trainer.train() + + dist.destroy_process_group() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-C', '--config', default='config.toml') + + args = parser.parse_args() + + config = toml.load(args.config) + args.world_size = config['DDP']['world_size'] + torch.multiprocessing.spawn( + run, args=(config, args,), nprocs=args.world_size, join=True) + diff --git a/train_sg.py b/train_sg.py new file mode 100644 index 0000000..df9b7fb --- /dev/null +++ b/train_sg.py @@ -0,0 +1,52 @@ +""" +single GPU version. +""" +import os +os.environ["CUDA_VISIBLE_DEVICES"]="0" +import toml +import torch + +from trainer_sg import Trainer +from model import DPCRN +from datasets import MyDataset +from loss_factory import loss_wavmag, loss_mse, loss_hybrid + +seed = 0 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +# torch.cuda.manual_seed_all(seed) + +def run(config, device): + + train_dataset = MyDataset(**config['train_dataset']) + train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, **config['train_dataloader']) + + validation_dataset = MyDataset(**config['validation_dataset']) + validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, **config['validation_dataloader']) + + model = DPCRN(**config['network_config']) + model.to(device) + + + optimizer = torch.optim.Adam(params=model.parameters(), lr=config['optimizer']['lr']) + + if config['loss']['loss_func'] == 'wav_mag': + loss = loss_wavmag() + elif config['loss']['loss_func'] == 'mse': + loss = loss_mse() + elif config['loss']['loss_func'] == 'hybrid': + loss = loss_hybrid() + else: + raise(NotImplementedError) + + trainer = Trainer(config=config, model=model,optimizer=optimizer, loss_func=loss, + train_dataloader=train_dataloader, validation_dataloader=validation_dataloader, + device=device) + + trainer.train() + +if __name__ == '__main__': + device = torch.device("cuda") + config = toml.load('config.toml') + run(config, device) + diff --git a/trainer.py b/trainer.py new file mode 100644 index 0000000..60da381 --- /dev/null +++ b/trainer.py @@ -0,0 +1,224 @@ +""" +multiple GPUs version, using DDP training. +""" +import os +import torch +import toml +from datetime import datetime +from tqdm import tqdm +from glob import glob +import soundfile as sf +from torch.utils.tensorboard import SummaryWriter +from pesq import pesq +from distributed_utils import reduce_value + + +class Trainer: + def __init__(self, config, model, optimizer, loss_func, + train_dataloader, validation_dataloader, train_sampler, args): + self.config = config + self.model = model + + self.optimizer = optimizer + + self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, 'min', factor=0.5, patience=5,verbose=True) + + self.loss_func = loss_func + + self.train_dataloader = train_dataloader + self.validation_dataloader = validation_dataloader + + self.train_sampler = train_sampler + self.rank = args.rank + self.device = args.device + + ## training config + self.trainer_config = config['trainer'] + self.epochs = self.trainer_config['epochs'] + self.save_checkpoint_interval = self.trainer_config['save_checkpoint_interval'] + self.clip_grad_norm_value = self.trainer_config['clip_grad_norm_value'] + self.resume = self.trainer_config['resume'] + + if not self.resume: + self.exp_path = self.trainer_config['exp_path'] + '_' + datetime.now().strftime("%Y-%m-%d-%Hh%Mm") + + else: + self.exp_path = self.trainer_config['exp_path'] + '_' + self.trainer_config['resume_datetime'] + + self.log_path = os.path.join(self.exp_path, 'logs') + self.checkpoint_path = os.path.join(self.exp_path, 'checkpoints') + self.sample_path = os.path.join(self.exp_path, 'val_samples') + + os.makedirs(self.log_path, exist_ok=True) + os.makedirs(self.checkpoint_path, exist_ok=True) + os.makedirs(self.sample_path, exist_ok=True) + + ## save the config + if self.rank == 0: + with open( + os.path.join( + self.exp_path, 'config.toml'.format(datetime.now().strftime("%Y-%m-%d-%Hh%Mm"))), 'w') as f: + + toml.dump(config, f) + + self.writer = SummaryWriter(self.log_path) + + self.start_epoch = 1 + self.best_score = 0 + + if self.resume: + self._resume_checkpoint() + + self.sr = config['listener']['listener_sr'] + + self.loss_func = self.loss_func.to(self.device) + + + def _set_train_mode(self): + self.model.train() + + def _set_eval_mode(self): + self.model.eval() + + def _save_checkpoint(self, epoch, score): + state_dict = {'epoch': epoch, + 'optimizer': self.optimizer.state_dict(), + 'model': self.model.module.state_dict()} + + torch.save(state_dict, os.path.join(self.checkpoint_path, f'model_{str(epoch).zfill(4)}.tar')) + + if score > self.best_score: + self.state_dict_best = state_dict.copy() + self.best_score = score + + def _resume_checkpoint(self): + latest_checkpoints = sorted(glob(os.path.join(self.checkpoint_path, 'model_*.tar')))[-1] + + map_location = self.device + checkpoint = torch.load(latest_checkpoints, map_location=map_location) + + self.start_epoch = checkpoint['epoch'] + 1 + self.optimizer.load_state_dict(checkpoint['optimizer']) + self.model.module.load_state_dict(checkpoint['model']) + + def _train_epoch(self, epoch): + total_loss = 0 + self.train_dataloader = tqdm(self.train_dataloader, ncols=110) + + for step, (mixture, target) in enumerate(self.train_dataloader, 1): + mixture = mixture.to(self.device) + target = target.to(self.device) + + esti_tagt = self.model(mixture) + + loss = self.loss_func(esti_tagt, target) + loss = reduce_value(loss) + total_loss += loss.item() + + self.train_dataloader.desc = ' train[{}/{}][{}]'.format( + epoch, self.epochs + self.start_epoch-1, datetime.now().strftime("%Y-%m-%d-%H:%M")) + + self.train_dataloader.postfix = 'loss={:.3f}'.format(total_loss / step) + + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm_value) + self.optimizer.step() + + if self.device != torch.device("cpu"): + torch.cuda.synchronize(self.device) + + if self.rank == 0: + self.writer.add_scalars('lr', {'lr': self.optimizer.param_groups[0]['lr']}, epoch) + self.writer.add_scalars('train_loss', {'train_loss': total_loss / step}, epoch) + + + @torch.no_grad() + def _validation_epoch(self, epoch): + total_loss = 0 + total_pesq_score = 0 + + self.validation_dataloader = tqdm(self.validation_dataloader, ncols=132) + for step, (mixture, target) in enumerate(self.validation_dataloader, 1): + mixture = mixture.to(self.device) + target = target.to(self.device) + + esti_tagt = self.model(mixture) + + loss = self.loss_func(esti_tagt, target) + loss = reduce_value(loss) + total_loss += loss.item() + + enhanced = torch.istft(esti_tagt[..., 0] + 1j*esti_tagt[..., 1], **self.config['FFT'], torch.hann_window(self.config['FFT']['win_length']).pow(0.5).to(self.device)) + clean = torch.istft(target[..., 0] + 1j*target[..., 1], **self.config['FFT'], torch.hann_window(self.config['FFT']['win_length']).pow(0.5).to(self.device)) + + enhanced = enhanced.squeeze().cpu().numpy() + clean = clean.squeeze().cpu().numpy() + + pesq_score = pesq(16000, clean, enhanced, 'wb') + pesq_score = reduce_value(torch.tensor(pesq_score, device=self.device)) + total_pesq_score += pesq_score + + if self.args==0 and step <= 3: + sf.write(os.path.join(self.sample_path, + '{}_enhanced_epoch{}_pesq={:.3f}.wav'.format(step, epoch, pesq_score)), + enhanced, 16000) + sf.write(os.path.join(self.sample_path, + '{}_clean.wav'.format(step)), + clean, 16000) + + self.validation_dataloader.desc = 'validate[{}/{}][{}]'.format( + epoch, self.epochs + self.start_epoch-1, datetime.now().strftime("%Y-%m-%d-%H:%M")) + + self.validation_dataloader.postfix = 'loss={:.2f}, pesq={:.4f}'.format( + total_loss / step, total_pesq_score / step) + + + if self.device != torch.device("cpu"): + torch.cuda.synchronize(self.device) + + if self.rank == 0: + self.writer.add_scalars( + 'val_loss', {'val_loss': total_loss / step, + 'pesq': total_pesq_score / step}, epoch) + + return total_loss / step, total_pesq_score / step + + + def train(self): + if self.rank == 0: + timestamp_txt = os.path.join(self.exp_path, 'timestamp.txt') + mode = 'a' if os.path.exists(timestamp_txt) else 'w' + with open(timestamp_txt, mode) as f: + f.write('[{}] start for {} epochs\n'.format( + datetime.now().strftime("%Y-%m-%d-%H:%M"), self.epochs)) + + if self.resume: + self._resume_checkpoint() + + for epoch in range(self.start_epoch, self.epochs + self.start_epoch): + self.train_sampler.set_epoch(epoch) + + self._set_train_mode() + self._train_epoch(epoch) + + self._set_eval_mode() + valid_loss, score = self._validation_epoch(epoch) + + self.scheduler.step(valid_loss) + + if (self.rank == 0) and (epoch % self.save_checkpoint_interval == 0): + self._save_checkpoint(epoch, score) + + if self.rank == 0: + torch.save(self.state_dict_best, + os.path.join(self.checkpoint_path, + 'best_model_{}.tar'.format(str(self.state_dict_best['epoch']).zfill(4)))) + + print('------------Training for {} epochs has done!------------'.format(self.epochs)) + + with open(timestamp_txt, 'a') as f: + f.write('[{}] end\n'.format(datetime.now().strftime("%Y-%m-%d-%H:%M"))) + + \ No newline at end of file diff --git a/trainer_sg.py b/trainer_sg.py new file mode 100644 index 0000000..2f0576d --- /dev/null +++ b/trainer_sg.py @@ -0,0 +1,202 @@ +""" +single GPU version. +""" +import os +import torch +import toml +from datetime import datetime +from tqdm import tqdm +from glob import glob +import soundfile as sf +from torch.utils.tensorboard import SummaryWriter +from pesq import pesq + + +class Trainer: + def __init__(self, config, model, optimizer, loss_func, + train_dataloader, validation_dataloader, device): + self.config = config + self.model = model + self.optimizer = optimizer + self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, 'min', factor=0.5, patience=5,verbose=True) + self.loss_func = loss_func + + self.train_dataloader = train_dataloader + self.validation_dataloader = validation_dataloader + + self.device = device + + ## training config + self.trainer_config = config['trainer'] + self.epochs = self.trainer_config['epochs'] + self.save_checkpoint_interval = self.trainer_config['save_checkpoint_interval'] + self.clip_grad_norm_value = self.trainer_config['clip_grad_norm_value'] + self.resume = self.trainer_config['resume'] + + if not self.resume: + self.exp_path = self.trainer_config['exp_path'] + '_' + datetime.now().strftime("%Y-%m-%d-%Hh%Mm") + + else: + self.exp_path = self.trainer_config['exp_path'] + '_' + self.trainer_config['resume_datetime'] + + self.log_path = os.path.join(self.exp_path, 'logs') + self.checkpoint_path = os.path.join(self.exp_path, 'checkpoints') + self.sample_path = os.path.join(self.exp_path, 'val_samples') + + os.makedirs(self.log_path, exist_ok=True) + os.makedirs(self.checkpoint_path, exist_ok=True) + os.makedirs(self.sample_path, exist_ok=True) + + ## save the config + with open( + os.path.join( + self.exp_path, 'config.toml'.format(datetime.now().strftime("%Y-%m-%d-%Hh%Mm"))), 'w') as f: + + toml.dump(config, f) + + self.writer = SummaryWriter(self.log_path) + + self.start_epoch = 1 + self.best_score = 0 + + if self.resume: + self._resume_checkpoint() + + self.sr = config['listener']['listener_sr'] + + self.loss_func = self.loss_func.to(self.device) + + + def _set_train_mode(self): + self.model.train() + + def _set_eval_mode(self): + self.model.eval() + + def _save_checkpoint(self, epoch, score): + state_dict = {'epoch': epoch, + 'optimizer': self.optimizer.state_dict(), + 'model': self.model.state_dict()} + + torch.save(state_dict, os.path.join(self.checkpoint_path, f'model_{str(epoch).zfill(4)}.tar')) + + if score > self.best_score: + self.state_dict_best = state_dict.copy() + self.best_score = score + + def _resume_checkpoint(self): + latest_checkpoints = sorted(glob(os.path.join(self.checkpoint_path, 'model_*.tar')))[-1] + + map_location = self.device + checkpoint = torch.load(latest_checkpoints, map_location=map_location) + + self.start_epoch = checkpoint['epoch'] + 1 + self.optimizer.load_state_dict(checkpoint['optimizer']) + self.model.load_state_dict(checkpoint['model']) + + def _train_epoch(self, epoch): + total_loss = 0 + self.train_dataloader = tqdm(self.train_dataloader, ncols=120) + + for step, (mixture, target) in enumerate(self.train_dataloader, 1): + mixture = mixture.to(self.device) + target = target.to(self.device) + + esti_tagt = self.model(mixture) + + loss = self.loss_func(esti_tagt, target) + total_loss += loss.item() + + self.train_dataloader.desc = ' train[{}/{}][{}]'.format( + epoch, self.epochs + self.start_epoch-1, datetime.now().strftime("%Y-%m-%d-%H:%M")) + + self.train_dataloader.postfix = 'train_loss={:.3f}'.format(total_loss / step) + + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm_value) + self.optimizer.step() + + self.writer.add_scalars('lr', {'lr': self.optimizer.param_groups[0]['lr']}, epoch) + self.writer.add_scalars('train_loss', {'train_loss': total_loss / step}, epoch) + + + @torch.no_grad() + def _validation_epoch(self, epoch): + total_loss = 0 + total_pesq_score = 0 + + self.validation_dataloader = tqdm(self.validation_dataloader, ncols=132) + for step, (mixture, target) in enumerate(self.validation_dataloader, 1): + mixture = mixture.to(self.device) + target = target.to(self.device) + + esti_tagt = self.model(mixture) + + loss = self.loss_func(esti_tagt, target) + total_loss += loss.item() + + enhanced = torch.istft(esti_tagt[..., 0] + 1j*esti_tagt[..., 1], **self.config['FFT'], torch.hann_window(self.config['FFT']['win_length']).pow(0.5).to(self.device)) + clean = torch.istft(target[..., 0] + 1j*target[..., 1], **self.config['FFT'], torch.hann_window(self.config['FFT']['win_length']).pow(0.5).to(self.device)) + + enhanced = enhanced.squeeze().cpu().numpy() + clean = clean.squeeze().cpu().numpy() + + pesq_score = pesq(16000, clean, enhanced, 'wb') + total_pesq_score += pesq_score + + if step <= 3: + sf.write(os.path.join(self.sample_path, + '{}_enhanced_epoch{}_pesq={:.3f}.wav'.format(step, epoch, pesq_score)), + enhanced, 16000) + sf.write(os.path.join(self.sample_path, + '{}_clean.wav'.format(step)), + clean, 16000) + + self.validation_dataloader.desc = 'validate[{}/{}][{}]'.format( + epoch, self.epochs + self.start_epoch-1, datetime.now().strftime("%Y-%m-%d-%H:%M")) + + self.validation_dataloader.postfix = 'valid_loss={:.3f}, pesq={:.4f}'.format( + total_loss / step, total_pesq_score / step) + + self.writer.add_scalars( + 'val_loss', {'val_loss': total_loss / step, + 'pesq': total_pesq_score / step}, epoch) + + return total_loss / step, total_pesq_score / step + + + def train(self): + timestamp_txt = os.path.join(self.exp_path, 'timestamp.txt') + mode = 'a' if os.path.exists(timestamp_txt) else 'w' + with open(timestamp_txt, mode) as f: + f.write('[{}] start for {} epochs\n'.format( + datetime.now().strftime("%Y-%m-%d-%H:%M"), self.epochs)) + + if self.resume: + self._resume_checkpoint() + + for epoch in range(self.start_epoch, self.epochs + self.start_epoch): + + self._set_train_mode() + self._train_epoch(epoch) + + self._set_eval_mode() + valid_loss, pesq_score = self._validation_epoch(epoch) + + self.scheduler.step(valid_loss) + + if epoch % self.save_checkpoint_interval == 0: + self._save_checkpoint(epoch, pesq_score) + + torch.save(self.state_dict_best, + os.path.join(self.checkpoint_path, + 'best_model_{}.tar'.format(str(self.state_dict_best['epoch']).zfill(4)))) + + print('------------Training for {} epochs has done!------------'.format(self.epochs)) + + with open(timestamp_txt, 'a') as f: + f.write('[{}] end\n'.format(datetime.now().strftime("%Y-%m-%d-%H:%M"))) + + \ No newline at end of file