From 6ce40ce91c11197d6670b9669f79f164477866d6 Mon Sep 17 00:00:00 2001 From: Rong Xiaobin Date: Tue, 14 May 2024 11:13:28 +0800 Subject: [PATCH] Update infer_loader.py --- infer_loader.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/infer_loader.py b/infer_loader.py index 157612c..19382e8 100644 --- a/infer_loader.py +++ b/infer_loader.py @@ -1,6 +1,4 @@ import os -os.environ["CUDA_VISIBLE_DEVICES"]="1" - import toml import torch import pandas as pd @@ -15,11 +13,11 @@ @torch.no_grad() -def infer(cfg_yaml): +def infer_loader(cfg_yaml): save_wavs = input('>>> Save wavs? (y/n) ') if save_wavs == 'y': - mark = input('>>> Please enter a tag for the saved wav names: ') + tag = input('>>> Please enter a tag for the saved wav names: ') cfg_toml = toml.load(cfg_yaml.network.cfg_toml) cfg_toml['validation_dataset']['train_folder'] = '/data/ssd0/xiaobin.rong/Datasets/DNS3/test/' @@ -35,7 +33,6 @@ def infer(cfg_yaml): validation_dataset = MyDataset(**cfg_toml['validation_dataset']) validation_filename = validation_dataset.file_name - validation_dataloader = torch.utils.data.DataLoader(validation_dataset, **cfg_toml['validation_dataloader']) ### load model @@ -56,7 +53,7 @@ def infer(cfg_yaml): mixture = mixture.to(device) target = target.to(device) - estimate= model(mixture) # [B, F, T, 2] + estimate= model(mixture) enhanced = torch.istft(estimate[..., 0] + 1j*estimate[..., 1], **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5).to(device)) clean = torch.istft(target[..., 0] + 1j*target[..., 1], **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5).to(device)) @@ -73,7 +70,7 @@ def infer(cfg_yaml): ## save wavs if save_wavs == 'y': - save_name = "{}_{}_{:.2f}_{:.2f}_{:.2f}.wav".format(validation_filename[step], mark, sisnr_score, pesq_score, estoi_score) + save_name = "{}_{}_{:.2f}_{:.2f}_{:.2f}.wav".format(validation_filename[step], tag, sisnr_score, pesq_score, estoi_score) sf.write( os.path.join(netout_folder, save_name), out, cfg_toml['listener']['listener_sr']) @@ -93,6 +90,14 @@ def infer(cfg_yaml): if __name__ == "__main__": - cfg_yaml = OmegaConf.load('config.yaml') - infer(cfg_yaml) + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-C', '--config', default='cfg_yaml.yaml') + parser.add_argument('-D', '--device', default='0', help='The index of the available device, only single GPU supported') + + args = parser.parse_args() + os.environ["CUDA_VISIBLE_DEVICES"] = args.device + + cfg_yaml = OmegaConf.load(args.config) + infer_loader(cfg_yaml)