Commit

Add files via upload
Xiaobin-Rong authored Nov 26, 2023
1 parent 5e25a55 commit f2701f6
Showing 15 changed files with 1,258 additions and 0 deletions.
41 changes: 41 additions & 0 deletions README.md
# A code template for training DNN-based speech enhancement models
A training code template is highly valuable for deep learning engineers, as it can significantly improve their efficiency. Coding styles vary from person to person, and some are better than others; my philosophy is to prioritize simplicity. In this spirit, I am sharing a practical organizational structure for the training code files of a speech enhancement (SE) project. The primary focus is on keeping it concise and intuitive rather than aiming for comprehensiveness.

## File Specification
For training:
* `config.toml`: Specifies the training configurations.
* `datasets.py`: Provides the dataset class for the dataloader.
* `distributed_utils.py`: Assists with Distributed Data Parallel (DDP) training.
* `model.py`: Defines the model.
* `loss_factory.py`: Provides various useful loss functions in SE.
* `train_sg.py`: Conducts the training process for a single GPU machine.
* `train.py`: Conducts the training process for multiple GPUs.
* `trainer_sg.py`: Encapsulates various functions during training for a single GPU machine.
* `trainer.py`: Encapsulates various functions during training for multiple GPUs.

For evaluation:
* `config.yaml`: Specifies evaluation paths.
* `infer_folder.py`: Conducts evaluation on a folder of WAV files.
* `infer_loader.py`: Conducts evaluation using a dataloader.
* `score_utils.py`: Provides calculations for various metrics.

## Usage
When starting a new SE project, you should follow these steps:
1. Modify `datasets.py`;
2. Define your own `model.py` (a minimal sketch of the expected interface follows this list);
3. Modify `config.toml` to match your training setup;
4. Select a loss function in `loss_factory.py`, or create a new one if needed;
5. Leave `trainer.py` and `trainer_sg.py` as they are; they rarely need modification;
6. Run `train.py` or `train_sg.py`, depending on the number of available GPUs;
7. Before evaluation, remember to modify `config.yaml` to ensure that the paths are correctly configured.
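
`model.py` is yours to write. Judging from `infer_folder.py`, the only contract is that the constructor accepts the keys in `[network_config]` and that `forward` maps a real-view noisy STFT of shape (B, F, T, 2) to an enhanced STFT of the same shape. Below is a minimal placeholder satisfying that contract; the layers and the `hidden_dim` argument are illustrative, not the actual DPCRN architecture:

```python
import torch
import torch.nn as nn

class DPCRN(nn.Module):
    """Placeholder with the interface infer_folder.py expects; the real model goes here."""
    def __init__(self, hidden_dim=64):  # accepts the keys from [network_config]
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(2, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, 2, kernel_size=3, padding=1),
        )

    def forward(self, x):             # x: (B, F, T, 2), real view of the noisy STFT
        x = x.permute(0, 3, 1, 2)     # -> (B, 2, F, T) for Conv2d
        x = self.net(x)
        return x.permute(0, 2, 3, 1)  # -> (B, F, T, 2), enhanced STFT
```

For step 6, `python train_sg.py` suffices on a single GPU; the multi-GPU `train.py` expects the `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` environment variables read in `distributed_utils.py`, which a launcher such as `torchrun` sets.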

## Note
The code is originally intended for Linux systems; if you adapt it to the Windows platform, you may encounter certain issues:
* Path incompatibility: the Linux-style file paths used in the configs are not valid on Windows.

* Challenges in installing the `pesq` package: installation on Windows may not be straightforward and may require additional steps or configurations.

Please keep these considerations in mind when working with the code on the Windows platform.

## Acknowledgement
This code template heavily borrows from the excellent [Sheffield_Clarity_CEC1_Entry](https://github.com/TuZehai/Sheffield_Clarity_CEC1_Entry) repository in many aspects.
51 changes: 51 additions & 0 deletions config.toml
[network_config]  # model-specific keyword arguments, passed to the model constructor


[DDP]
world_size = 2  # number of available GPUs

[optimizer]
lr = 1e-3

[loss]
loss_func = 'hybrid'

[listener]
listener_sr = 16000

[FFT]
n_fft = 512
hop_length = 256
win_length = 512

[train_dataset]
train_folder = '/data/ssd0/xiaobin.rong/Datasets/DNS3/train'
shuffle = false
num_tot = 0
wav_len = 0

[train_dataloader]
batch_size = 16
num_workers = 4
drop_last = true
pin_memory = true

[validation_dataset]
train_folder = '/data/ssd0/xiaobin.rong/Datasets/DNS3/dev'  # key name matches the MyDataset argument
shuffle = false
num_tot = 0
wav_len = 0

[validation_dataloader]
batch_size = 1
num_workers = 4
pin_memory = true

[trainer]
epochs = 120
save_checkpoint_interval = 1
clip_grad_norm_value = 3.0
exp_path = '/data/ssd0/xiaobin.rong/project_se/DNS3/exp_dpcrn'
resume = false
resume_datetime = ''
resume_step = 0
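
`trainer.py` itself is not shown in this excerpt. As a hedged sketch, the `[optimizer]` and `[trainer]` keys above are typically consumed as below; the Adam choice and the stand-in model and loss are assumptions, not the repository's actual code:

```python
import toml
import torch
import torch.nn as nn

config = toml.load('config.toml')

model = nn.Linear(10, 10)  # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=config['optimizer']['lr'])

for epoch in range(config['trainer']['epochs']):
    loss = model(torch.randn(4, 10)).pow(2).mean()  # stand-in for the SE loss
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients to the configured norm to stabilize training.
    torch.nn.utils.clip_grad_norm_(model.parameters(), config['trainer']['clip_grad_norm_value'])
    optimizer.step()
```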
11 changes: 11 additions & 0 deletions config.yaml
path:
  exp_folder: ${network.root}/${network.exp_name}/${network.cpt_name}/enhanced  # where the enhanced speech is stored

network:
  root: /data/ssd0/xiaobin.rong/project_se/DNS3
  exp_name: exp_dpcrn_2023-08-02-11h32m
  cpt_name: model_0082
  checkpoint: ${network.exp_path}/checkpoints/${network.cpt_name}.tar
  exp_path: ${network.root}/${network.exp_name}
  cfg_toml: ${network.exp_path}/config.toml
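
The `${...}` references are OmegaConf interpolations; the inference scripts load this file with `OmegaConf.load`, and the references resolve at access time:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load('config.yaml')
# Interpolations resolve on access:
print(cfg.network.checkpoint)
# -> /data/ssd0/xiaobin.rong/project_se/DNS3/exp_dpcrn_2023-08-02-11h32m/checkpoints/model_0082.tar
print(cfg.path.exp_folder)
# -> /data/ssd0/xiaobin.rong/project_se/DNS3/exp_dpcrn_2023-08-02-11h32m/model_0082/enhanced
```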

75 changes: 75 additions & 0 deletions datasets.py
import os
import toml
import random
import torch
import pandas as pd
import soundfile as sf
from torch.utils import data


class MyDataset(data.Dataset):
    def __init__(self, train_folder, shuffle, num_tot, wav_len=0, n_fft=512, hop_length=256, win_length=512):
        super().__init__()
        ### We store the noisy-clean pairs in the same folder, and use a CSV file to manage all the WAV files.
        self.file_name = pd.read_csv(os.path.join(train_folder, 'INFO.csv'))['file_name'].to_list()

        if shuffle:
            random.seed(7)
            random.shuffle(self.file_name)

        if num_tot != 0:  # 0 means using all files
            self.file_name = self.file_name[: num_tot]

        self.train_folder = train_folder
        self.wav_len = wav_len

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def __getitem__(self, idx):
        noisy, fs = sf.read(os.path.join(self.train_folder, self.file_name[idx] + '_noisy.wav'), dtype="float32")
        clean, fs = sf.read(os.path.join(self.train_folder, self.file_name[idx] + '_clean.wav'), dtype="float32")

        noisy = torch.tensor(noisy)
        clean = torch.tensor(clean)

        if self.wav_len != 0:  # randomly crop a fixed-length segment; 0 means using the full utterance
            # randint is endpoint-inclusive, so clips exactly wav_len seconds long are handled.
            start = random.randint(0, len(clean) - self.wav_len * fs)
            noisy = noisy[start: start + self.wav_len * fs]
            clean = clean[start: start + self.wav_len * fs]

        # The sqrt-Hann window is used for both analysis (here) and synthesis (the iSTFT at inference),
        # so the pair allows near-perfect reconstruction.
        noisy = torch.stft(noisy, self.n_fft, self.hop_length, self.win_length, torch.hann_window(self.win_length).pow(0.5), return_complex=False)
        clean = torch.stft(clean, self.n_fft, self.hop_length, self.win_length, torch.hann_window(self.win_length).pow(0.5), return_complex=False)

        return noisy, clean

    def __len__(self):
        return len(self.file_name)


if __name__ == '__main__':
    from tqdm import tqdm
    config = toml.load('config.toml')

    device = torch.device('cuda')

    train_dataset = MyDataset(**config['train_dataset'], **config['FFT'])
    train_dataloader = data.DataLoader(train_dataset, **config['train_dataloader'])

    validation_dataset = MyDataset(**config['validation_dataset'], **config['FFT'])
    validation_dataloader = data.DataLoader(validation_dataset, **config['validation_dataloader'])

    print(len(train_dataloader), len(validation_dataloader))

    for noisy, clean in tqdm(train_dataloader):
        print(noisy.shape, clean.shape)
        break

    for noisy, clean in tqdm(validation_dataloader):
        print(noisy.shape, clean.shape)
        break
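
A quick way to sanity-check the sqrt-Hann analysis/synthesis pair used above (a standalone snippet, not part of the commit): the same window passed to `torch.stft` and `torch.istft` should reconstruct the waveform almost exactly.

```python
import torch

n_fft, hop, win = 512, 256, 512
window = torch.hann_window(win).pow(0.5)  # sqrt-Hann, as in MyDataset

x = torch.randn(16000)  # 1 s of noise at 16 kHz
X = torch.stft(x, n_fft, hop, win, window, return_complex=True)
x_rec = torch.istft(X, n_fft, hop, win, window, length=len(x))

print(torch.allclose(x, x_rec, atol=1e-5))  # True: the pair is (nearly) perfectly reconstructing
```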



70 changes: 70 additions & 0 deletions distributed_utils.py
import os

import torch
import torch.distributed as dist


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched with torchrun / torch.distributed.launch
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # Launched under SLURM
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'  # Communication backend; NCCL is recommended for NVIDIA GPUs.
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()


def cleanup():
    dist.destroy_process_group()


def is_dist_avail_and_initialized():
    """Check if a distributed environment is available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def reduce_value(value, average=True):
    world_size = get_world_size()
    if world_size < 2:  # for single GPU
        return value

    with torch.no_grad():
        dist.all_reduce(value)
        if average:
            value /= world_size

    return value
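
`trainer.py` consumes these helpers. As a hedged sketch of the intended pattern (the loss value here is a stand-in): each rank computes its local loss, and `reduce_value` averages it across processes so that logging on the main process reflects the global mean.

```python
import torch
from distributed_utils import reduce_value, is_main_process

# Stand-in for this rank's mini-batch loss (hypothetical value).
loss = torch.tensor(0.5, device='cuda')

# Average the loss over all ranks so every process sees the same value;
# on a single GPU, reduce_value simply returns its input.
loss_for_logging = reduce_value(loss.detach().clone(), average=True)
if is_main_process():
    print(f'global mean loss: {loss_for_logging.item():.4f}')
```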
52 changes: 52 additions & 0 deletions infer_folder.py
"""
conduct evaluation on a folder of WAV files, with computing SISNR, PESQ, ESTOI, and DNSMOS.
"""
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import toml
import torch
from tqdm import tqdm
import soundfile as sf
from omegaconf import OmegaConf
from model import DPCRN


cfg_yaml = OmegaConf.load('config.yaml')
test_folder = '/data/ssd0/xiaobin.rong/Datasets/DNS3/blind_test_set/dns-challenge-3-final-evaluation/wideband_16kHz/noisy_clips_wb_16kHz/'
test_wavnames = list(filter(lambda x: x.endswith("wav"), os.listdir(test_folder)))

cfg_toml = toml.load(cfg_yaml.network.cfg_toml)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

netout_folder = f'{cfg_yaml.path.exp_folder}'
os.makedirs(netout_folder, exist_ok=True)

### load model
model = DPCRN(**cfg_toml['network_config'])
model.to(device)
checkpoint = torch.load(cfg_yaml.network.checkpoint, map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()

for param in model.parameters():
param.requires_grad = False

### compute SISNR, PESQ and ESTOI
with torch.no_grad():
for name in tqdm(test_wavnames):
noisy, fs = sf.read(os.path.join(test_folder, name), dtype="float32")
noisy = torch.stft(torch.from_numpy(noisy), **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5))
noisy = noisy.to(device)

estimate= model(noisy[None, ...]) # (B,F,T,2)

enhanced = torch.istft(estimate[..., 0] + 1j*estimate[..., 1], **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5).to(device))
out = enhanced.cpu().detach().numpy().squeeze()

sf.write(os.path.join(netout_folder, name[:-4]+'_enh.wav'), out, fs)

### compute DNSMOS
os.chdir('DNSMOS')
out_dir = os.path.join(netout_folder, 'dnsmos_enhanced_p808.csv')
os.system(f'python dnsmos_local_p808.py -t {netout_folder} -o {out_dir}')
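
`score_utils.py` is not shown in this commit. As a reference point, here is a minimal SI-SNR implementation following the standard scale-invariant definition (zero-mean both signals, project the estimate onto the clean reference, then take the energy ratio in dB); the function and variable names are illustrative:

```python
import numpy as np

def si_snr(estimate: np.ndarray, reference: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SNR in dB between an enhanced signal and its clean reference."""
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    # Project the estimate onto the reference to get the target component.
    s_target = np.dot(estimate, reference) * reference / (np.dot(reference, reference) + eps)
    e_noise = estimate - s_target
    return 10 * np.log10((np.dot(s_target, s_target) + eps) / (np.dot(e_noise, e_noise) + eps))
```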
