-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathgenerate_train_val.py
46 lines (34 loc) · 1.34 KB
/
generate_train_val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 21 13:40:34 2019
@author: Winham
generate_train_val.py: 生成训练集和验证集
注意:运行前新建文件夹train_sigs,train_labels,val_sigs,val_labels
"""
import os
import numpy as np
from sklearn.model_selection import train_test_split
# Source directories: signal segments and their matching label arrays.
# Signals and labels are paired purely by position in sorted file-name order,
# so both directories are assumed to sort identically -- TODO confirm the
# naming scheme guarantees this.
Sig_path = 'G:/ECG_UNet/119_SEG/'
Label_path = 'G:/ECG_UNet/119_LABEL/'
# Destination directories for the split; created on demand if missing.
train_sig_path = 'G:/ECG_UNet/train_sigs/'
train_label_path = 'G:/ECG_UNet/train_labels/'
val_sig_path = 'G:/ECG_UNet/val_sigs/'
val_label_path = 'G:/ECG_UNet/val_labels/'


def _list_files(directory):
    """Return the sorted names of the regular files directly inside *directory*.

    Subdirectories are skipped so they can never be positionally paired with
    a file from the other directory (np.load would fail on them anyway).
    """
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


def _copy_pairs(tag, sig_names, label_names, sig_src, label_src,
                sig_dst, label_dst):
    """Load each (signal, label) pair from the source dirs and save it to the destination dirs.

    Args:
        tag: Prefix for the progress line ('Train' or 'Val').
        sig_names: Signal file names, relative to *sig_src*.
        label_names: Label file names, relative to *label_src*; paired with
            *sig_names* by position.
        sig_src, label_src: Directories the arrays are loaded from.
        sig_dst, label_dst: Directories the arrays are written to; created
            if they do not exist yet.
    """
    os.makedirs(sig_dst, exist_ok=True)
    os.makedirs(label_dst, exist_ok=True)
    for number, (sig_name, label_name) in enumerate(
            zip(sig_names, label_names), start=1):
        print('{} No.{}:{}'.format(tag, number, sig_name))
        sig = np.load(os.path.join(sig_src, sig_name))
        label = np.load(os.path.join(label_src, label_name))
        # np.save appends '.npy' when the target name does not already end with it.
        np.save(os.path.join(sig_dst, sig_name), sig)
        np.save(os.path.join(label_dst, label_name), label)


def generate_train_val(sig_dir=Sig_path, label_dir=Label_path,
                       train_sig_dir=train_sig_path,
                       train_label_dir=train_label_path,
                       val_sig_dir=val_sig_path,
                       val_label_dir=val_label_path,
                       test_size=100, random_state=42):
    """Split the signal/label pairs into training and validation sets on disk.

    Args:
        sig_dir, label_dir: Source directories of signal and label arrays.
        train_sig_dir, train_label_dir: Output directories for the training split.
        val_sig_dir, val_label_dir: Output directories for the validation split.
        test_size: Validation-set size, passed to sklearn's train_test_split
            (an int is an absolute count; the original data used 500 training
            and 100 validation pairs).
        random_state: Seed for the split, so it is reproducible.

    Returns:
        A ``(sig_train, sig_val)`` tuple of the signal file names assigned to
        each split.

    Raises:
        ValueError: If the two source directories hold different numbers of
            files, since positional pairing would then be wrong.
    """
    sig_files = _list_files(sig_dir)
    label_files = _list_files(label_dir)
    if len(sig_files) != len(label_files):
        raise ValueError(
            'Signal/label count mismatch: {} files in {!r} vs {} files in {!r}'
            .format(len(sig_files), sig_dir, len(label_files), label_dir))

    sig_train, sig_val, label_train, label_val = train_test_split(
        sig_files, label_files, test_size=test_size, random_state=random_state)

    _copy_pairs('Train', sig_train, label_train, sig_dir, label_dir,
                train_sig_dir, train_label_dir)
    _copy_pairs('Val', sig_val, label_val, sig_dir, label_dir,
                val_sig_dir, val_label_dir)
    return sig_train, sig_val


if __name__ == '__main__':
    generate_train_val()