forked from snapfinger/pancreas-seg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
testvis.py
183 lines (137 loc) · 6.17 KB
/
testvis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""
This script tests a trained NN model and visualizes its output.
"""
import os
import sys
import time

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model, load_model
from keras.layers import Input, Activation, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D, ZeroPadding2D, BatchNormalization
from keras.optimizers import Adam, SGD
from keras.callbacks import ModelCheckpoint
from keras import backend as K
import tensorflow as tf

from data import load_train_data, load_test_data
from utils import *
K.set_image_data_format('channels_last')  # Tensorflow dimension ordering

# ---------------------------------------------------------------------------
# Command-line configuration.
#
# Eleven positional arguments are required, in this order:
#   1 data_path      root directory; "models/" and "test-records/" live below it
#   2 model_to_test  base name of the .h5 model file (no extension)
#   3 cur_fold       fold identifier
#   4 plane          forwarded to pad_2d
#   5-7 im_z im_y im_x   integers forwarded to pad_2d
#   8 high_range     upper intensity clamp (float)
#   9 low_range      lower intensity clamp (float)
#   10 margin        pixels added around each slice's label bounding box (int)
#   11 vis           the string "true" turns on per-slice figures
# ---------------------------------------------------------------------------
if len(sys.argv) < 12:
    # Fail with a readable usage line instead of an IndexError further down.
    sys.exit("usage: %s data_path model_to_test cur_fold plane im_z im_y im_x "
             "high_range low_range margin vis" % sys.argv[0])

data_path = sys.argv[1] + "/"
model_path = data_path + "models/"

# directory that receives the test statistics and the saved predictions
rst_path = data_path + "test-records/"
if not os.path.exists(rst_path):
    os.makedirs(rst_path)

model_to_test = sys.argv[2]
cur_fold = sys.argv[3]
plane = sys.argv[4]
im_z = int(sys.argv[5])
im_y = int(sys.argv[6])
im_x = int(sys.argv[7])
high_range = float(sys.argv[8])
low_range = float(sys.argv[9])
margin = int(sys.argv[10])
vis = sys.argv[11]

# predictions of the trained model are stored per fold
pred_path = os.path.join(rst_path, "pred-%s/" % cur_fold)
if not os.path.exists(pred_path):
    os.makedirs(pred_path)
"""
Dice coefficient and cost functions for training
"""
# Additive term used in both the numerator and the denominator of dice_coef;
# it keeps the ratio defined when both sums are zero.
smooth = 1.0
def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient of two tensors.

    Both tensors are flattened with the Keras backend, and the result is
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` where
    ``smooth`` is the module-level constant.
    """
    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    numerator = 2.0 * overlap + smooth
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + smooth
    return numerator / denominator
def dice_coef_loss(y_true, y_pred):
    """Return the negated Dice coefficient, so that minimising it maximises overlap."""
    score = dice_coef(y_true, y_pred)
    return -score
def _padded_bbox(mask_2d, pad):
    """Return (row_slice, col_slice) around the non-zero pixels of ``mask_2d``.

    The box is grown by ``pad`` pixels on every side and clipped to the array
    bounds. Returns ``None`` when the mask has no non-zero pixel.
    """
    rows, cols = np.nonzero(mask_2d)
    if len(rows) == 0:
        return None
    n_rows, n_cols = mask_2d.shape
    row_slice = slice(max(rows.min() - pad, 0), min(rows.max() + pad + 1, n_rows))
    col_slice = slice(max(cols.min() - pad, 0), min(cols.max() + pad + 1, n_cols))
    return row_slice, col_slice


def _show_slice(cropped, pred_vis, cropped_mask, sli):
    """Show input crop, prediction and ground truth side by side for one slice."""
    fig = plt.figure()
    panels = (
        ("input test image", cropped),
        ("prediction", pred_vis),
        ("ground truth", cropped_mask),
    )
    for position, (title, picture) in enumerate(panels, 1):
        ax = fig.add_subplot(1, 3, position)
        ax.set_title(title)
        ax.imshow(picture, cmap=plt.cm.gray)
    # Set the window title through the figure manager: the canvas-level
    # set_window_title was removed from newer Matplotlib releases.
    manager = getattr(fig.canvas, "manager", None)
    if manager is not None:
        manager.set_window_title("slice %s" % sli)
    plt.axis('off')  # acts on the current (last added) axes only
    plt.show()


def test(model_to_test, current_fold, plane, rst_dir, vis):
    """Run a saved model over every case of one test fold and record DSC.

    For each case listed in ``testing_set_filename(current_fold)``, the image
    is clamped to ``[low_range, high_range]`` and rescaled. Every labelled
    slice is cropped to its label bounding box (grown by ``margin``), padded,
    and passed to the model. The output is thresholded at 0.5 and written
    back into a prediction volume. The volume is saved as ``<case>.npy`` and
    its DSC against the label is recorded.

    Args:
        model_to_test: base name (without ``.h5``) of the model under the
            module-level ``model_path``.
        current_fold: fold identifier; selects the case list and names the
            ``pred-<fold>/`` output directory.
        plane: forwarded unchanged to ``pad_2d``.
        rst_dir: directory receiving ``test_stats.csv``,
            ``<model_to_test>.csv`` and the ``pred-<fold>/`` directory.
        vis: when equal to the string ``"true"``, a three-panel figure is
            shown for every processed slice.

    Raises:
        ValueError: when a KeyboardInterrupt arrives while slices are being
            processed.
    """
    banner = "-" * 50
    print(banner)
    print("loading model  %s" % model_to_test)
    print(banner)

    model = load_model(
        model_path + model_to_test + '.h5',
        custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})

    # Read the case list and release the file handle straight away.
    with open(testing_set_filename(current_fold), 'r') as list_file:
        volume_list = list_file.read().splitlines()

    # Honour the arguments instead of the module-level globals; for the
    # script's own call these resolve to the same paths as before.
    pred_dir = os.path.join(rst_dir, "pred-%s/" % current_fold)
    if not os.path.exists(pred_dir):
        os.makedirs(pred_dir)

    total = len(volume_list)
    dsc = np.zeros((total, 2))  # column 0: case number, column 1: DSC

    # iterate all test cases
    for i, entry in enumerate(volume_list):
        s = entry.split(' ')
        image = np.load(s[1])
        label = np.load(s[2])
        # NOTE(review): takes the text after the first "00" in the image path,
        # up to the next "." -- presumably paths look like ".../0001.npy";
        # confirm no other "00" can occur earlier in the path.
        case_num = s[1].split("00")[1].split(".")[0]
        print("testing case:  %s" % case_num)

        # move the last axis to the front so that axis 0 indexes slices
        image_ = np.transpose(image, (2, 0, 1))
        label_ = np.transpose(label, (2, 0, 1))

        # standardize test data: clamp to the intensity window, then rescale
        image_[image_ < low_range] = low_range
        image_[image_ > high_range] = high_range
        image_ = (image_ - low_range) / float(high_range - low_range)

        # full-size volume that collects the per-slice predictions
        pred = np.zeros_like(image_)

        for sli in range(label_.shape[0]):
            try:
                # crop each slice to the smallest bounding box of its label
                bbox = _padded_bbox(label_[sli], margin)
                if bbox is None:
                    continue  # slice without label: prediction stays zero
                rows, cols = bbox

                cropped = image_[sli, rows, cols]
                cropped_mask = label_[sli, rows, cols]

                image_padded_ = pad_2d(cropped, plane, 0, im_x, im_y, im_z)
                image_padded_prep = preprocess_front(preprocess(image_padded_))

                out_ori = (model.predict(image_padded_prep) > 0.5).astype(np.uint8)
                # drop the padding again so the output matches the crop
                out = out_ori[:, 0:cropped.shape[0], 0:cropped.shape[1], :].reshape(cropped.shape)

                pred[sli, rows, cols] = out
                pred_vis = pred[sli, rows, cols]

                if vis == "true":
                    _show_slice(cropped, pred_vis, cropped_mask, sli)

            except KeyboardInterrupt:
                print('KeyboardInterrupt caught')
                raise ValueError("terminate because of keyboard interruption")

        # ------------ write out for visualization ---------------
        # prediction made by the trained model
        np.save(os.path.join(pred_dir, case_num + ".npy"), pred)

        # compute DSC
        cur_dsc, _, _, _ = DSC_computation(label_, pred)
        print(cur_dsc)
        dsc[i][0] = case_num
        dsc[i][1] = cur_dsc

    dsc_mean = np.mean(dsc[:, 1])
    dsc_std = np.std(dsc[:, 1])

    # record test dsc mean and standard deviation for each fold in one file
    with open(os.path.join(rst_dir, 'test_stats.csv'), 'a+') as stats_file:
        stats_file.write("%s,%s,%s,%s\n" % (current_fold, model_to_test, dsc_mean, dsc_std))

    print("---------------------------------")
    print("mean:  %s" % dsc_mean)
    print("std:  %s" % dsc_std)

    # record test result case by case
    np.savetxt(os.path.join(rst_dir, model_to_test + ".csv"), dsc,
               fmt="%i, %.5f", delimiter=",", header="case_num,DSC")
if __name__ == "__main__":
    # Run the evaluation with the command-line configuration and report the
    # wall-clock time it took.
    start_time = time.time()
    test(model_to_test, cur_fold, plane, rst_path, vis)
    elapsed = time.time() - start_time
    # Single-argument call form: same output on Python 2, valid on Python 3.
    print("-----------test done, total time used: %s ------------" % elapsed)