-
Notifications
You must be signed in to change notification settings - Fork 3
/
test.py
122 lines (91 loc) · 4.18 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import matplotlib.pyplot as plt
import skimage.io
#from PIL import Image
import utils
from constants import input_width, input_height,\
scale_fact, verbosity, get_model_save_path, tests_path
def run_tests(model, history):
    """Plot the train/validation loss curves from `history`, then run the
    self-designed evaluation and prediction tests on `model`."""
    print("\nExtracting the History (Callback) of metrics to display graph.")
    # Training & validation loss values over the epochs
    for metric in ('loss', 'val_loss'):
        plt.plot(history.history[metric])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper right')
    plt.show()

    # Then we run a few self-designed tests
    x_test, y_test = extract_tests()
    evaluate(model, x_test, y_test)
    predicts(model, x_test, y_test)
def extract_tests(num_images=11):
    """Load the benchmark test images and build the (input, target) arrays.

    Args:
        num_images: number of numbered PNG files to read from `tests_path`
            (expected to be named "0.png" … f"{num_images - 1}.png").
            Defaults to 11, the original hard-coded test-set size.

    Returns:
        Tuple `(x, y)` of numpy arrays: `y` holds the center-cropped
        high-resolution ground-truth images and `x` the downscaled
        versions used as prediction inputs.
    """
    print("Starting tests. Extracting images to feed.")
    x = []
    y = []
    # HR target dimensions — invariant, so hoisted out of the loop
    hr_width = input_width * scale_fact
    hr_height = input_height * scale_fact
    for i in range(num_images):
        # Extracting the benchmark images (HR), center-cropped to target size
        y_test = utils.crop_center(skimage.io.imread(tests_path + str(i) + ".png"),
                                   hr_width,
                                   hr_height)
        y.append(y_test)
        # Extracting middle part for prediction test
        x.append(utils.single_downscale(y_test, hr_width, hr_height))
    return np.array(x), np.array(y)
def evaluate(model, x_test, y_test):
    """Run Keras' built-in evaluation on the test set and print the loss."""
    print("Starting evaluation.")
    test_loss = model.evaluate(
        x_test,
        y_test,
        batch_size=1,
        verbose=verbosity,
    )
    print('[evaluate] Test loss:', test_loss)
    #print('[evaluate] Test accuracy:', test_acc)
    # score = model.evaluate(x_test, y_test, verbose=False)
    # model.metrics_names
    # print('Test score: ', score[0]) # Loss on test
    # print('Test accuracy: ', score[1])
def predicts(model, x_test, y_test):
    """Predict an upscaled image for each test input, show every prediction
    next to its input and ground truth, then offer to save the model.

    Args:
        model: trained Keras model whose `predict` is called per image.
        x_test: array of low-resolution input images.
        y_test: array of matching high-resolution ground-truth images.
    """
    print("Starting predictions.")
    # Predict one image at a time: each input is wrapped into a single-image
    # batch, and `predict` returns a batch, so element [0] is the result.
    predictions = [model.predict(np.expand_dims(lr_img, 0))[0]
                   for lr_img in x_test]

    # TODO: take the BICUBIC enlargement without re-reading the file from disk
    # TODO: save the predictions (e.g. to pictures/final_tests/predictions/)
    #       via utils.save_np_img; plt.savefig produced a white image — see
    #       https://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.savefig

    # Showing output vs expected image
    for lr_img, pred, hr_img in zip(x_test, predictions, y_test):
        show_pred_output(lr_img, pred, hr_img)
    prompt_model_save(model)
def show_pred_output(input, pred, truth):
    """Display the LR input, the network output and the HR target side by side."""
    plt.figure(figsize=(20, 20))
    plt.suptitle("Results")
    lr_size = str(input_width) + "x" + str(input_height)
    hr_size = str(input_width * scale_fact) + "x" + str(input_height * scale_fact)
    panels = [
        ("Input: " + lr_size, input),
        ("Output: " + hr_size, pred),
        ("Target (HR): " + hr_size, truth),
    ]
    for position, (title, img) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        # Hide the x-axis ticks/labels for a cleaner comparison view
        plt.imshow(img, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)
    plt.show()
def prompt_model_save(model):
    """Interactively ask the user whether to save the trained model to disk.

    Saves to `get_model_save_path()` when the answer is "y"
    (case-insensitive, surrounding whitespace ignored); any other answer
    skips saving. The local reference to the model is deleted afterwards.
    """
    save_bool = input("Save progress from this model (y/n) ?\n")
    # Be lenient about capitalization/whitespace so "Y" or "y " still saves.
    if save_bool.strip().lower() == "y":
        model.save(get_model_save_path())
        print("Model saved! :)")
        # model.save_weights('save/model_weights.h5')
    # NOTE(review): original indentation was lost in this copy — the `del`
    # may have been inside the if; it only drops the local name either way.
    del model  # deletes the existing model