-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_test.py
executable file
·65 lines (49 loc) · 1.97 KB
/
run_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import print_function
from __future__ import division
import click
import json
import os
import numpy as np
import SimpleITK as sitk
from keras.optimizers import Adam
from evaluation import getDSC, getHausdorff, getVS
from models.DRUNet import get_model
from metrics import dice_coef, dice_coef_loss
def get_eval_metrics(true_mask, pred_mask, output_file=''):
    """Score a predicted segmentation against the reference segmentation.

    Both arrays are wrapped as SimpleITK images and handed to the
    ``evaluation`` helpers ``getDSC``, ``getHausdorff`` and ``getVS``.
    These are presumably Dice, 95th-percentile Hausdorff distance and
    volumetric similarity; see ``evaluation`` for the exact definitions.

    Args:
        true_mask: reference label array (anything accepted by
            ``sitk.GetImageFromArray``).
        pred_mask: predicted label array. NOTE(review): presumably must
            match ``true_mask`` in shape -- confirm in ``evaluation``.
        output_file: when not the empty string, the scores are also
            written there as a JSON object with keys ``dsc``, ``h95`` and
            ``vs``. The file is created or truncated.

    Returns:
        Tuple ``(dsc, h95, vs)`` in that order.
    """
    reference_img = sitk.GetImageFromArray(true_mask)
    predicted_img = sitk.GetImageFromArray(pred_mask)

    # Evaluated left to right: DSC, then Hausdorff, then VS.
    scores = (
        getDSC(reference_img, predicted_img),
        getHausdorff(reference_img, predicted_img),
        getVS(reference_img, predicted_img),
    )

    if output_file != '':
        # Key order matches the score order, so the JSON layout is stable.
        report = dict(zip(('dsc', 'h95', 'vs'), scores))
        with open(output_file, 'w+') as handle:
            json.dump(report, handle)
    return scores
@click.command()
@click.argument('test_imgs_np_file', type=click.STRING)
@click.argument('test_masks_np_file', type=click.STRING)
@click.argument('pretrained_model', type=click.STRING)
@click.option('--output_pred_mask_file', type=click.STRING, default='')
@click.option('--output_metric_file', type=click.STRING, default='')
@click.option('--num_classes', type=click.INT, default=9,
              help='Number of output classes of the pretrained model '
                   '(default: 9).')
def main(test_imgs_np_file, test_masks_np_file, pretrained_model,
         output_pred_mask_file='', output_metric_file='', num_classes=9):
    """Evaluate a pretrained DRUNet model on saved test arrays.

    TEST_IMGS_NP_FILE and TEST_MASKS_NP_FILE are loaded with numpy.load.
    PRETRAINED_MODEL is the weights file passed to model.load_weights.
    Optionally saves the predicted label maps (.npy) and the metrics (JSON).
    """
    # Validate up front, before loading the test arrays or building the
    # network. A bare ``assert`` would be stripped under ``python -O``.
    if not os.path.isfile(pretrained_model):
        raise click.BadParameter(
            'file not found: {}'.format(pretrained_model),
            param_hint='pretrained_model')

    test_imgs = np.load(test_imgs_np_file)
    test_masks = np.load(test_masks_np_file)
    # Masks may carry a trailing channel axis. Keep channel 0 so the
    # reference has the same rank as the argmax prediction below.
    # Already-squeezed 3-D masks are used as-is.
    if test_masks.ndim > 3:
        test_masks = test_masks[:, :, :, 0]

    # Single-channel input of the test images' spatial size.
    # NOTE(review): assumes test_imgs is (N, H, W, 1) -- confirm against
    # models.DRUNet if multi-channel inputs are ever used.
    img_shape = (test_imgs.shape[1], test_imgs.shape[2], 1)
    model = get_model(img_shape=img_shape, num_classes=num_classes)
    model.load_weights(pretrained_model)

    # Collapse per-class scores to an integer label map by taking the argmax
    # over the last axis (axis 3 of this 2-D model's 4-D output).
    pred_masks = model.predict(test_imgs).argmax(axis=-1)

    dsc, h95, vs = get_eval_metrics(test_masks, pred_masks, output_metric_file)
    if output_pred_mask_file != '':
        np.save(output_pred_mask_file, pred_masks)
    return (dsc, h95, vs)
if __name__ == '__main__':
    # Invoke the click command; click reads the CLI arguments from sys.argv.
    main()