# datasets.py (forked from Xiaobin-Rong/SEtrain)
import os
import random

import pandas as pd
import soundfile as sf
import toml
import torch
from torch.utils import data

class MyDataset(data.Dataset):
    def __init__(self, train_folder, shuffle, num_tot, wav_len=0, n_fft=512, hop_length=256, win_length=512):
        super().__init__()
        # Noisy-clean pairs are stored in the same folder; an INFO.csv file
        # lists all the WAV files (see the layout sketch after this method).
        self.file_name = pd.read_csv(os.path.join(train_folder, 'INFO.csv'))['file_name'].to_list()
        if shuffle:
            random.seed(7)  # fixed seed so every run sees the same ordering
            random.shuffle(self.file_name)
        if num_tot != 0:
            # num_tot == 0 means "use everything"; otherwise keep the first num_tot files
            self.file_name = self.file_name[:num_tot]
        self.train_folder = train_folder
        self.wav_len = wav_len  # segment length in seconds; 0 disables random cropping
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
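
    # Expected on-disk layout (an illustration; the stem "p226_001" is a
    # made-up example, not taken from the original project):
    #   train_folder/INFO.csv            -> has a 'file_name' column, e.g. "p226_001"
    #   train_folder/p226_001_noisy.wav  -> noisy input
    #   train_folder/p226_001_clean.wav  -> clean target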
    def __getitem__(self, idx):
        noisy, fs = sf.read(os.path.join(self.train_folder, self.file_name[idx] + '_noisy.wav'), dtype="float32")
        clean, fs = sf.read(os.path.join(self.train_folder, self.file_name[idx] + '_clean.wav'), dtype="float32")
        noisy = torch.tensor(noisy)
        clean = torch.tensor(clean)
        if self.wav_len != 0:
            # Randomly crop a wav_len-second segment at the same offset in both
            # files. Guard against clips no longer than the requested length
            # (the original random.choice over an empty range would raise).
            seg_len = int(self.wav_len * fs)
            start = random.randint(0, max(len(clean) - seg_len, 0))
            noisy = noisy[start: start + seg_len]
            clean = clean[start: start + seg_len]
        # The sqrt-Hann analysis window pairs with a sqrt-Hann synthesis window
        # for perfect reconstruction via torch.istft. With return_complex=False
        # the output shape is (n_fft // 2 + 1, n_frames, 2): real and imaginary parts.
        window = torch.hann_window(self.win_length).pow(0.5)
        noisy = torch.stft(noisy, self.n_fft, self.hop_length, self.win_length, window, return_complex=False)
        clean = torch.stft(clean, self.n_fft, self.hop_length, self.win_length, window, return_complex=False)
        return noisy, clean
    def __len__(self):
        return len(self.file_name)
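
# A minimal config.toml sketch matching the keys consumed below. All values
# are placeholder assumptions, not the original project's settings:
#
#   [train_dataset]
#   train_folder = "/path/to/train"
#   shuffle = true
#   num_tot = 0
#   wav_len = 3
#
#   [validation_dataset]
#   train_folder = "/path/to/validation"
#   shuffle = false
#   num_tot = 0
#   wav_len = 0
#
#   [FFT]
#   n_fft = 512
#   hop_length = 256
#   win_length = 512
#
#   [train_dataloader]
#   batch_size = 16
#   shuffle = true
#   num_workers = 4
#
#   [validation_dataloader]
#   batch_size = 1      # wav_len = 0 leaves clips uncropped, so batch one at a time
#   num_workers = 4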

if __name__ == '__main__':
    from tqdm import tqdm

    config = toml.load('config.toml')

    train_dataset = MyDataset(**config['train_dataset'], **config['FFT'])
    train_dataloader = data.DataLoader(train_dataset, **config['train_dataloader'])
    validation_dataset = MyDataset(**config['validation_dataset'], **config['FFT'])
    validation_dataloader = data.DataLoader(validation_dataset, **config['validation_dataloader'])
    print(len(train_dataloader), len(validation_dataloader))

    # Smoke test: pull one batch from each loader and print the spectrogram
    # shapes, expected (batch, n_fft // 2 + 1, n_frames, 2).
    for noisy, clean in tqdm(train_dataloader):
        print(noisy.shape, clean.shape)
        break

    for noisy, clean in tqdm(validation_dataloader):
        print(noisy.shape, clean.shape)
        break
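
    # Round-trip sanity check (a sketch added here, not part of the original
    # file): with the sqrt-Hann window, torch.istft should invert the STFT
    # above up to boundary effects. `clean` still holds the last validation
    # batch, a real-view spectrogram of shape (batch, freq, frames, 2).
    fft_cfg = config['FFT']
    window = torch.hann_window(fft_cfg['win_length']).pow(0.5)
    spec = torch.view_as_complex(clean.contiguous())  # -> complex (batch, freq, frames)
    wav = torch.istft(spec, fft_cfg['n_fft'], fft_cfg['hop_length'],
                      fft_cfg['win_length'], window)
    print('istft output shape:', wav.shape)  # (batch, num_samples)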