
Commit

Update infer_loader.py
Xiaobin-Rong authored May 14, 2024
1 parent 2139fad commit 6ce40ce
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions infer_loader.py
@@ -1,6 +1,4 @@
 import os
-os.environ["CUDA_VISIBLE_DEVICES"]="1"
-
 import toml
 import torch
 import pandas as pd
@@ -15,11 +13,11 @@


 @torch.no_grad()
-def infer(cfg_yaml):
+def infer_loader(cfg_yaml):

     save_wavs = input('>>> Save wavs? (y/n) ')
     if save_wavs == 'y':
-        mark = input('>>> Please enter a tag for the saved wav names: ')
+        tag = input('>>> Please enter a tag for the saved wav names: ')

     cfg_toml = toml.load(cfg_yaml.network.cfg_toml)
     cfg_toml['validation_dataset']['train_folder'] = '/data/ssd0/xiaobin.rong/Datasets/DNS3/test/'
@@ -35,7 +33,6 @@ def infer(cfg_yaml):

     validation_dataset = MyDataset(**cfg_toml['validation_dataset'])
     validation_filename = validation_dataset.file_name
-
     validation_dataloader = torch.utils.data.DataLoader(validation_dataset, **cfg_toml['validation_dataloader'])

     ### load model
@@ -56,7 +53,7 @@ def infer(cfg_yaml):
         mixture = mixture.to(device)
         target = target.to(device)

-        estimate= model(mixture) # [B, F, T, 2]
+        estimate= model(mixture)

         enhanced = torch.istft(estimate[..., 0] + 1j*estimate[..., 1], **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5).to(device))
         clean = torch.istft(target[..., 0] + 1j*target[..., 1], **cfg_toml['FFT'], window=torch.hann_window(cfg_toml['FFT']['win_length']).pow(0.5).to(device))
@@ -73,7 +70,7 @@ def infer(cfg_yaml):

         ## save wavs
         if save_wavs == 'y':
-            save_name = "{}_{}_{:.2f}_{:.2f}_{:.2f}.wav".format(validation_filename[step], mark, sisnr_score, pesq_score, estoi_score)
+            save_name = "{}_{}_{:.2f}_{:.2f}_{:.2f}.wav".format(validation_filename[step], tag, sisnr_score, pesq_score, estoi_score)

             sf.write(
                 os.path.join(netout_folder, save_name), out, cfg_toml['listener']['listener_sr'])
@@ -93,6 +90,14 @@ def infer(cfg_yaml):


 if __name__ == "__main__":
-    cfg_yaml = OmegaConf.load('config.yaml')
-    infer(cfg_yaml)
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-C', '--config', default='cfg_yaml.yaml')
+    parser.add_argument('-D', '--device', default='0', help='The index of the available device, only single GPU supported')
+
+    args = parser.parse_args()
+    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+
+    cfg_yaml = OmegaConf.load(args.config)
+    infer_loader(cfg_yaml)

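For context on the updated entry point, here is a minimal usage sketch assuming this commit's argparse interface; the explicit argv list and the print call are illustrative only and not part of the repository.

import argparse
import os

# Rebuild the same parser as the new __main__ block; equivalent to running
#   python infer_loader.py -C config.yaml -D 1
parser = argparse.ArgumentParser()
parser.add_argument('-C', '--config', default='cfg_yaml.yaml')
parser.add_argument('-D', '--device', default='0',
                    help='The index of the available device, only single GPU supported')
args = parser.parse_args(['-C', 'config.yaml', '-D', '1'])

# The visible GPU now comes from the CLI instead of the hardcoded
# os.environ["CUDA_VISIBLE_DEVICES"] = "1" removed at the top of the file.
os.environ['CUDA_VISIBLE_DEVICES'] = args.device
print(args.config, os.environ['CUDA_VISIBLE_DEVICES'])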