-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathCPSC_train_multi_leads.py
108 lines (87 loc) · 4.31 KB
/
CPSC_train_multi_leads.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 26 14:41:17 2018
@author: Winham
# CPSC_train_multi_leads.py: trains one network per ECG lead and saves each model;
# the body is essentially identical to CPSC_train_single_lead.py
"""
import os
import warnings
import numpy as np
import tensorflow as tf
from keras import backend as bk
from keras import optimizers
from keras.layers import Input
from keras.models import Model, load_model
from keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from sklearn.preprocessing import scale
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from CPSC_model import Net
from CPSC_config import Config
import CPSC_utils as utils
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'
warnings.filterwarnings("ignore")
config = Config()
records_name = np.array(os.listdir(config.DATA_PATH))
records_label = np.load(config.REVISED_LABEL) - 1
class_num = len(np.unique(records_label))
train_val_records, test_records, train_val_labels, test_labels = train_test_split(
records_name, records_label, test_size=0.2, random_state=config.RANDOM_STATE)
del test_records, test_labels
train_records, val_records, train_labels, val_labels = train_test_split(
train_val_records, train_val_labels, test_size=0.2, random_state=config.RANDOM_STATE)
train_records, train_labels = utils.oversample_balance(train_records, train_labels, config.RANDOM_STATE)
val_records, val_labels = utils.oversample_balance(val_records, val_labels, config.RANDOM_STATE)
for i in range(config.LEAD_NUM):
TARGET_LEAD = i
print('Fetching data for Lead ' + str(TARGET_LEAD) + ' ...-----------------\n')
train_x = utils.Fetch_Pats_Lbs_sLead(train_records, Path=config.DATA_PATH,
target_lead=TARGET_LEAD, seg_num=config.SEG_NUM,
seg_length=config.SEG_LENGTH)
train_y = to_categorical(train_labels, num_classes=class_num)
val_x = utils.Fetch_Pats_Lbs_sLead(val_records, Path=config.DATA_PATH,
target_lead=TARGET_LEAD, seg_num=config.SEG_NUM,
seg_length=config.SEG_LENGTH)
val_y = to_categorical(val_labels, num_classes=class_num)
model_name = 'net_lead_' + str(TARGET_LEAD) + '.hdf5'
print('Scaling data ...-----------------\n')
for j in range(train_x.shape[0]):
train_x[j, :, :] = scale(train_x[j, :, :], axis=0)
for j in range(val_x.shape[0]):
val_x[j, :, :] = scale(val_x[j, :, :], axis=0)
batch_size = 64
epochs = 100
momentum = 0.9
keep_prob = 0.5
bk.clear_session()
tf.reset_default_graph()
inputs = Input(shape=(config.SEG_LENGTH, config.SEG_NUM))
net = Net()
outputs, _ = net.nnet(inputs, keep_prob, num_classes=class_num)
model = Model(inputs=inputs, outputs=outputs)
opt = optimizers.SGD(lr=config.lr_schedule(0), momentum=momentum)
model.compile(optimizer=opt, loss='categorical_crossentropy',
metrics=['categorical_accuracy'])
checkpoint = ModelCheckpoint(filepath=config.MODEL_PATH+model_name,
monitor='val_categorical_accuracy', mode='max',
save_best_only='True')
lr_scheduler = LearningRateScheduler(config.lr_schedule)
callback_lists = [checkpoint, lr_scheduler]
model.fit(x=train_x, y=train_y, batch_size=batch_size, epochs=epochs, verbose=2,
validation_data=(val_x, val_y), callbacks=callback_lists)
del train_x, train_y
model = load_model(config.MODEL_PATH + model_name)
pred_vt = model.predict(val_x, batch_size=batch_size, verbose=1)
pred_v = np.argmax(pred_vt, axis=1)
true_v = np.argmax(val_y, axis=1)
del val_x, val_y
Conf_Mat_val = confusion_matrix(true_v, pred_v)
print('\nResult for Lead ' + str(TARGET_LEAD) + '-----------------------------\n')
print(Conf_Mat_val)
F1s_val = []
for j in range(class_num):
f1t = 2 * Conf_Mat_val[j][j] / (np.sum(Conf_Mat_val[j, :]) + np.sum(Conf_Mat_val[:, j]))
print('| F1-' + config.CLASS_NAME[j] + ':' + str(f1t) + ' |')
F1s_val.append(f1t)
print('F1-mean: ' + str(np.mean(F1s_val)))