forked from Xiaobin-Rong/SEtrain
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1 parent 5e25a55, commit f2701f6
Showing 15 changed files with 1,258 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# A code template for training DNN-based speech enhancement models. | ||
A training code template is valuable for deep learning engineers because it can significantly improve their efficiency. Coding styles vary from person to person, and some are better than others. My philosophy is to prioritize simplicity. Here I share a practical way to organize the training code files for speech enhancement (SE). The focus is on keeping the template concise and intuitive rather than comprehensive. | |
|
||
## File Specification | ||
For training: | ||
* `config.toml`: Specifies the training configurations. | ||
* `datasets.py`: Provides the dataset class for the dataloader. | ||
* `distributed_utils.py`: Assists with Distributed Data Parallel (DDP) training. | ||
* `model.py`: Defines the model. | ||
* `loss_factory.py`: Provides various useful loss functions in SE. | ||
* `train_sg.py`: Conducts the training process for a single GPU machine. | ||
* `train.py`: Conducts the training process for multiple GPUs. | ||
* `trainer_sg.py`: Encapsulates various functions during training for a single GPU machine. | ||
* `trainer.py`: Encapsulates various functions during training for multiple GPUs. | ||
|
||
For evaluation: | ||
* `config.yaml`: Specifies evaluation paths. | ||
* `infer_folder.py`: Conducts evaluation on a folder of WAV files. | ||
* `infer_loader.py`: Conducts evaluation using a dataloader. | ||
* `score_utils.py`: Provides calculations for various metrics. | ||
|
||
## Usage | ||
When starting a new SE project, you should follow these steps: | ||
1. Modify `datasets.py`; | ||
2. Define your own `model.py`; | ||
3. Modify `config.toml` to match your training setup; | |
4. Select a loss function in `loss_factory.py`, or create a new one if needed; | |
5. You probably do not need to modify `trainer.py` or `trainer_sg.py`; | |
6. Run `train.py` or `train_sg.py`, depending on the number of available GPUs; | |
7. Before evaluation, remember to modify `config.yaml` to ensure that the paths are correctly configured. | ||
|
||
## Note | ||
The code was originally written for Linux. If you adapt it to Windows, you may run into the following issues: | |
* Incompatibility of paths: The file paths used in Linux systems may not be compatible with the file paths in Windows. | ||
|
||
* Challenges in installing the pesq package: The process of installing the pesq package on Windows may not be straightforward and may require additional steps or configurations. | ||
|
||
Please keep these considerations in mind when working with the code on the Windows platform. | ||
|
||
## Acknowledgement | ||
This code template borrows heavily from the excellent [Sheffield_Clarity_CEC1_Entry](https://github.com/TuZehai/Sheffield_Clarity_CEC1_Entry) repository. |
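
The path incompatibility mentioned in the Note can often be avoided by building paths with `pathlib` instead of hard-coding `/` separators. The following is a minimal sketch, assuming the `<name>_noisy.wav` / `<name>_clean.wav` naming used in `datasets.py`; the helper `pair_paths` and the example folders are hypothetical and not part of the template.

```python
from pathlib import Path, PurePosixPath, PureWindowsPath


def pair_paths(train_folder, file_name):
    """Return (noisy, clean) paths for one utterance, joined with the OS-native separator."""
    folder = Path(train_folder)
    return folder / f"{file_name}_noisy.wav", folder / f"{file_name}_clean.wav"


# The Pure*Path classes show how the same join renders on each platform,
# regardless of the OS this snippet runs on. Folder names are made up.
posix_noisy = PurePosixPath("/data/DNS3/train") / "fileid_0_noisy.wav"
windows_noisy = PureWindowsPath(r"D:\data\DNS3\train") / "fileid_0_noisy.wav"

noisy_path, clean_path = pair_paths("data/train", "fileid_0")
print(posix_noisy)
print(windows_noisy)
print(noisy_path.name, clean_path.name)
```

Using `Path` objects (or `os.path.join`, as `datasets.py` already does) leaves only the root folders in the config files to adjust per platform.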
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
[network_config] | ||
|
||
|
||
[DDP] | ||
world_size = 2 # number of available gpus | ||
|
||
[optimizer] | ||
lr = 1e-3 | ||
|
||
[loss] | ||
loss_func = 'hybrid' | ||
|
||
[listener] | ||
listener_sr = 16000 | ||
|
||
[FFT] | ||
n_fft = 512 | ||
hop_length = 256 | ||
win_length = 512 | ||
|
||
[train_dataset] | ||
train_folder = '/data/ssd0/xiaobin.rong/Datasets/DNS3/train' | ||
shuffle = false | ||
num_tot = 0 | ||
wav_len = 0 | ||
|
||
[train_dataloader] | ||
batch_size = 16 | ||
num_workers = 4 | ||
drop_last = true | ||
pin_memory = true | ||
|
||
[validation_dataset] | ||
train_folder = '/data/ssd0/xiaobin.rong/Datasets/DNS3/dev' | ||
shuffle = false | ||
num_tot = 0 | ||
wav_len = 0 | ||
|
||
[validation_dataloader] | ||
batch_size = 1 | ||
num_workers = 4 | ||
pin_memory = true | ||
|
||
[trainer] | ||
epochs = 120 | ||
save_checkpoint_interval = 1 | ||
clip_grad_norm_value = 3.0 | ||
exp_path = '/data/ssd0/xiaobin.rong/project_se/DNS3/exp_dpcrn' | ||
resume = false | ||
resume_datetime = '' | ||
resume_step = 0 |
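
To make the `[FFT]` settings more concrete, the sketch below works out what they imply for the spectrogram shape. It assumes 16 kHz audio (matching `listener_sr`) and the `torch.stft` defaults of a centered, one-sided transform; the helper `stft_shape` is hypothetical.

```python
# Values copied from the [FFT] section of config.toml.
n_fft, hop_length, win_length = 512, 256, 512
sample_rate = 16000  # assumption: 16 kHz audio, matching listener_sr


def stft_shape(num_samples, n_fft=n_fft, hop_length=hop_length):
    """Return (freq_bins, frames) of a one-sided, centered STFT."""
    freq_bins = n_fft // 2 + 1
    frames = 1 + num_samples // hop_length
    return freq_bins, frames


window_ms = 1000 * win_length / sample_rate
hop_ms = 1000 * hop_length / sample_rate
overlap = 1 - hop_length / win_length
print(window_ms, hop_ms, overlap)  # 32.0 16.0 0.5
print(stft_shape(sample_rate))     # (257, 63) for one second of audio
```

With these settings each frame covers 32 ms with 50% overlap, so one second of audio becomes a 257 x 63 spectrogram, plus a trailing dimension of 2 when real and imaginary parts are stored separately.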
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
path: | ||
exp_folder: ${network.root}/${network.exp_name}/${network.cpt_name}/enhanced # where the enhanced speech is stored | |
|
||
network: | ||
root: /data/ssd0/xiaobin.rong/project_se/DNS3 | ||
exp_name: exp_dpcrn_2023-08-02-11h32m | ||
cpt_name: model_0082 | ||
checkpoint: ${network.exp_path}/checkpoints/${network.cpt_name}.tar | ||
exp_path: ${network.root}/${network.exp_name} | ||
cfg_toml: ${network.exp_path}/config.toml | ||
|
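
The `${...}` references in `config.yaml` are interpolations that OmegaConf resolves when a value is accessed. To show what the entries expand to, here is a toy resolver written with the standard library only. It is not OmegaConf's implementation; the flattened dotted keys and the `resolve` helper are assumptions made for this illustration.

```python
import re

# Values copied from config.yaml, flattened to dotted keys for this toy example.
raw = {
    "network.root": "/data/ssd0/xiaobin.rong/project_se/DNS3",
    "network.exp_name": "exp_dpcrn_2023-08-02-11h32m",
    "network.cpt_name": "model_0082",
    "network.exp_path": "${network.root}/${network.exp_name}",
    "network.checkpoint": "${network.exp_path}/checkpoints/${network.cpt_name}.tar",
    "network.cfg_toml": "${network.exp_path}/config.toml",
    "path.exp_folder": "${network.root}/${network.exp_name}/${network.cpt_name}/enhanced",
}

_PATTERN = re.compile(r"\$\{([^}]+)\}")


def resolve(key, table=raw):
    """Recursively substitute ${dotted.key} references (toy resolver, no cycle detection)."""
    return _PATTERN.sub(lambda m: resolve(m.group(1), table), table[key])


print(resolve("network.checkpoint"))
print(resolve("path.exp_folder"))
```

In practice only `root`, `exp_name`, and `cpt_name` need to be edited before evaluation; the remaining paths follow from them.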
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import os | ||
import toml | ||
import random | ||
import torch | ||
import pandas as pd | ||
import soundfile as sf | ||
from torch.utils import data | ||
|
||
|
||
class MyDataset(data.Dataset): | ||
def __init__(self, train_folder, shuffle, num_tot, wav_len=0, n_fft=512, hop_length=256, win_length=512): | ||
super().__init__() | ||
### The noisy-clean pairs are stored in the same folder, and a CSV file (INFO.csv) lists all the WAV files. | |
self.file_name = pd.read_csv(os.path.join(train_folder, 'INFO.csv'))['file_name'].to_list() | ||
|
||
if shuffle: | ||
random.seed(7) | ||
random.shuffle(self.file_name) | ||
|
||
if num_tot != 0: | ||
self.file_name = self.file_name[: num_tot] | ||
|
||
self.train_folder = train_folder | ||
self.wav_len = wav_len | ||
|
||
self.n_fft = n_fft | ||
self.hop_length = hop_length | ||
self.win_length = win_length | ||
|
||
def __getitem__(self, idx): | ||
noisy, fs = sf.read(os.path.join(self.train_folder, self.file_name[idx] + '_noisy.wav'), dtype="float32") | ||
clean, fs = sf.read(os.path.join(self.train_folder, self.file_name[idx] + '_clean.wav'), dtype="float32") | ||
|
||
noisy = torch.tensor(noisy) | ||
clean = torch.tensor(clean) | ||
|
||
if self.wav_len != 0: | ||
start = random.randint(0, len(clean) - self.wav_len * fs) # inclusive upper bound, so a clip of exactly wav_len seconds also works | |
noisy = noisy[start: start + self.wav_len*fs] | ||
clean = clean[start: start + self.wav_len*fs] | ||
|
||
# return_complex=False is deprecated in recent PyTorch; view_as_real yields the same (F, T, 2) layout. | |
window = torch.hann_window(self.win_length).pow(0.5) | |
noisy = torch.view_as_real(torch.stft(noisy, self.n_fft, self.hop_length, self.win_length, window, return_complex=True)) | |
clean = torch.view_as_real(torch.stft(clean, self.n_fft, self.hop_length, self.win_length, window, return_complex=True)) | |
|
||
return noisy, clean | ||
|
||
def __len__(self): | ||
return len(self.file_name) | ||
|
||
|
||
if __name__=='__main__': | ||
from tqdm import tqdm | ||
config = toml.load('config.toml') | ||
|
||
device = torch.device('cuda') | ||
|
||
train_dataset = MyDataset(**config['train_dataset'], **config['FFT']) | ||
train_dataloader = data.DataLoader(train_dataset, **config['train_dataloader']) | ||
|
||
validation_dataset = MyDataset(**config['validation_dataset'], **config['FFT']) | ||
validation_dataloader = data.DataLoader(validation_dataset, **config['validation_dataloader']) | ||
|
||
print(len(train_dataloader), len(validation_dataloader)) | ||
|
||
for noisy, clean in tqdm(train_dataloader): | ||
print(noisy.shape, clean.shape) | ||
break | ||
# pass | ||
|
||
for noisy, clean in tqdm(validation_dataloader): | ||
print(noisy.shape, clean.shape) | ||
break | ||
|
||
|
||
|
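
The cropping step in `__getitem__` picks one start index and applies it to both signals so that the noisy and clean segments stay aligned. The sketch below isolates that idea on plain lists, using `random.randint`, whose inclusive upper bound also accepts a clip that is exactly one segment long. The function `random_crop` and its explicit length check are assumptions added for illustration.

```python
import random


def random_crop(noisy, clean, seg_len, rng=random):
    """Cut the same random seg_len-sample window out of both signals."""
    if len(clean) < seg_len:
        raise ValueError("clip is shorter than the requested segment")
    start = rng.randint(0, len(clean) - seg_len)  # upper bound is inclusive
    return noisy[start:start + seg_len], clean[start:start + seg_len]


rng = random.Random(7)  # local generator, so the global random state is untouched
clean = list(range(10))
noisy = [x + 0.5 for x in clean]

noisy_seg, clean_seg = random_crop(noisy, clean, seg_len=4, rng=rng)
full_noisy, full_clean = random_crop(noisy, clean, seg_len=10, rng=rng)
print(len(noisy_seg), len(clean_seg), full_clean == clean)
```

In the dataset class the segment length would be `wav_len * fs` samples, and the inputs would be the tensors read from the WAV pair.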
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import os | ||
|
||
import torch | ||
import torch.distributed as dist | ||
|
||
|
||
def init_distributed_mode(args): | ||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: | ||
args.rank = int(os.environ["RANK"]) | ||
args.world_size = int(os.environ['WORLD_SIZE']) | ||
args.gpu = int(os.environ['LOCAL_RANK']) | ||
elif 'SLURM_PROCID' in os.environ: | ||
args.rank = int(os.environ['SLURM_PROCID']) | ||
args.gpu = args.rank % torch.cuda.device_count() | ||
else: | ||
print('Not using distributed mode') | ||
args.distributed = False | ||
return | ||
|
||
args.distributed = True | ||
|
||
torch.cuda.set_device(args.gpu) | ||
args.dist_backend = 'nccl' # Communication backend, NVIDIA GPUs are recommended to use NCCL. | ||
print('| distributed init (rank {}): {}'.format( | ||
args.rank, args.dist_url), flush=True) | ||
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, | ||
world_size=args.world_size, rank=args.rank) | ||
dist.barrier() | ||
|
||
|
||
def cleanup(): | ||
dist.destroy_process_group() | ||
|
||
|
||
def is_dist_avail_and_initialized(): | ||
""" Check whether torch.distributed is available and the process group has been initialized. """ | |
if not dist.is_available(): | ||
return False | ||
if not dist.is_initialized(): | ||
return False | ||
return True | ||
|
||
|
||
def get_world_size(): | ||
if not is_dist_avail_and_initialized(): | ||
return 1 | ||
return dist.get_world_size() | ||
|
||
|
||
def get_rank(): | ||
if not is_dist_avail_and_initialized(): | ||
return 0 | ||
return dist.get_rank() | ||
|
||
|
||
def is_main_process(): | ||
return get_rank() == 0 | ||
|
||
|
||
def reduce_value(value, average=True): | ||
world_size = get_world_size() | ||
if world_size < 2: # for single GPU | ||
return value | ||
|
||
with torch.no_grad(): | ||
dist.all_reduce(value) | ||
if average: | ||
value /= world_size | ||
|
||
return value |
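
The branch logic of `init_distributed_mode` can be checked without any GPUs by feeding it a dictionary in place of `os.environ`. The sketch below mirrors only the environment-variable detection; the function `detect_launch` and its `gpus_per_node` parameter (standing in for `torch.cuda.device_count()`) are hypothetical.

```python
def detect_launch(env, gpus_per_node=1):
    """Return (rank, world_size, local_gpu), or None when no launcher variables are set."""
    if "RANK" in env and "WORLD_SIZE" in env:
        return int(env["RANK"]), int(env["WORLD_SIZE"]), int(env["LOCAL_RANK"])
    if "SLURM_PROCID" in env:
        rank = int(env["SLURM_PROCID"])
        # The SLURM branch of init_distributed_mode does not read a world size
        # from the environment, so this sketch reports it as None.
        return rank, None, rank % gpus_per_node
    return None


print(detect_launch({"RANK": "1", "WORLD_SIZE": "2", "LOCAL_RANK": "1"}))
print(detect_launch({"SLURM_PROCID": "5"}, gpus_per_node=4))
print(detect_launch({}))
```

One point this makes visible: under SLURM, `args.world_size` must already be set by the caller before `dist.init_process_group` is reached, because that branch never assigns it.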
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
""" | ||
Enhance a folder of noisy WAV files and score the enhanced output with DNSMOS. | |
""" | ||
import os | ||
os.environ["CUDA_VISIBLE_DEVICES"]="0" | ||
import toml | ||
import torch | ||
from tqdm import tqdm | ||
import soundfile as sf | ||
from omegaconf import OmegaConf | ||
from model import DPCRN | ||
|
||
|
||
cfg_yaml = OmegaConf.load('config.yaml') | ||
test_folder = '/data/ssd0/xiaobin.rong/Datasets/DNS3/blind_test_set/dns-challenge-3-final-evaluation/wideband_16kHz/noisy_clips_wb_16kHz/' | ||
test_wavnames = list(filter(lambda x: x.endswith("wav"), os.listdir(test_folder))) | ||
|
||
cfg_toml = toml.load(cfg_yaml.network.cfg_toml) | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
netout_folder = f'{cfg_yaml.path.exp_folder}' | ||
os.makedirs(netout_folder, exist_ok=True) | ||
|
||
### load model | ||
model = DPCRN(**cfg_toml['network_config']) | ||
model.to(device) | ||
checkpoint = torch.load(cfg_yaml.network.checkpoint, map_location=device) | ||
model.load_state_dict(checkpoint['model']) | ||
model.eval() | ||
|
||
for param in model.parameters(): | ||
param.requires_grad = False | ||
|
||
### enhance each noisy file and write the result to netout_folder | |
with torch.no_grad(): | ||
for name in tqdm(test_wavnames): | ||
noisy, fs = sf.read(os.path.join(test_folder, name), dtype="float32") | ||
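# torch.stft needs an explicit return_complex for real input; view_as_real gives the (F,T,2) layout the model expects. | |
window = torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5) | |
noisy = torch.view_as_real(torch.stft(torch.from_numpy(noisy), **cfg_toml['FFT'], window=window, return_complex=True)) | |
noisy = noisy.to(device) | |
|
||
estimate = model(noisy[None, ...]) # (B,F,T,2) | |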
|
||
enhanced = torch.istft(estimate[..., 0] + 1j*estimate[..., 1], **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5).to(device)) | ||
out = enhanced.cpu().detach().numpy().squeeze() | ||
|
||
sf.write(os.path.join(netout_folder, name[:-4]+'_enh.wav'), out, fs) | ||
|
||
### compute DNSMOS | ||
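import subprocess | |
out_csv = os.path.join(netout_folder, 'dnsmos_enhanced_p808.csv') | |
# An argument list avoids shell-quoting problems with paths that contain spaces; cwd replaces os.chdir. | |
subprocess.run(['python', 'dnsmos_local_p808.py', '-t', netout_folder, '-o', out_csv], cwd='DNSMOS', check=True) | |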
|
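
The script applies the same square-root Hann window in both `torch.stft` and `torch.istft`. With the hop set to half the window length, the product of the analysis and synthesis windows overlap-adds to one, which can be verified numerically. The check below uses only the standard library and assumes the periodic Hann definition that `torch.hann_window` uses by default; the helper `sqrt_hann` is hypothetical.

```python
import math

win_length, hop_length = 512, 256  # values from the [FFT] section of config.toml


def sqrt_hann(n, length=win_length):
    """Square root of a periodic Hann window evaluated at sample n."""
    return math.sqrt(0.5 - 0.5 * math.cos(2 * math.pi * n / length))


# Analysis window times synthesis window, summed over the two frames
# that overlap at each sample position n.
overlap_sums = [
    sqrt_hann(n) ** 2 + sqrt_hann(n + hop_length) ** 2
    for n in range(hop_length)
]
max_deviation = max(abs(s - 1.0) for s in overlap_sums)
print(max_deviation)
```

The deviation stays at floating-point rounding level, so the analysis/synthesis pair itself does not distort the signal and any change in the output comes from the model.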