-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathmodel.py
179 lines (136 loc) · 6.83 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Authors:
# Christian F. Baumgartner ([email protected])
# Lisa M. Koch ([email protected])
import tensorflow as tf
from tfwrapper import losses
import tensorflow.examples.tutorials.mnist
def inference(images, exp_config, training):
    '''
    Build the network output for a batch of images by delegating to the model function selected in the
    experiment configuration.
    :param images: The input image tensor, forwarded unchanged to the model function
    :param exp_config: The experiment configuration; must provide `model_handle` and `nlabels`
    :param training: Training flag, forwarded to the model function as its second positional argument
    :return: Whatever `exp_config.model_handle` returns
    '''
    model_fn = exp_config.model_handle
    return model_fn(images, training, nlabels=exp_config.nlabels)
def loss(logits, labels, nlabels, loss_type, weight_decay=0.0, class_weights=None):
    '''
    Loss to be minimised by the neural network
    :param logits: The output of the neural network before the softmax
    :param labels: The ground truth labels in standard (i.e. not one-hot) format
    :param nlabels: The number of GT labels
    :param loss_type: Can be 'weighted_crossentropy'/'crossentropy'/'dice'/'dice_onlyfg'/'crossentropy_and_dice'
    :param weight_decay: The weight for the L2 regularisation of the network parameters
    :param class_weights: Per-class weights, only used for 'weighted_crossentropy'. If None, the previous
                          hard-coded default [0.1, 0.3, 0.3, 0.3] is used (one entry per class, i.e. that
                          default presumably assumes nlabels == 4).
    :return: The total loss including weight decay, the loss without weight decay, only the weight decay
    :raises ValueError: If loss_type is not one of the supported values
    '''

    # Validate first so that an unsupported loss type does not leave stray ops behind in the graph.
    known_loss_types = ('weighted_crossentropy', 'crossentropy', 'dice', 'dice_onlyfg', 'crossentropy_and_dice')
    if loss_type not in known_loss_types:
        raise ValueError('Unknown loss: %s' % loss_type)

    if class_weights is None:
        class_weights = [0.1, 0.3, 0.3, 0.3]

    labels = tf.one_hot(labels, depth=nlabels)

    # Weight decay term: weighted sum of the L2 losses of everything in the 'weight_variables' collection.
    with tf.variable_scope('weights_norm'):
        weights_norm = tf.reduce_sum(
            input_tensor=weight_decay * tf.stack(
                [tf.nn.l2_loss(ii) for ii in tf.get_collection('weight_variables')]
            ),
            name='weights_norm'
        )

    if loss_type == 'weighted_crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss_weighted(logits, labels,
                                                                          class_weights=class_weights)
    elif loss_type == 'crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(logits, labels)
    elif loss_type == 'dice':
        segmentation_loss = losses.dice_loss(logits, labels, only_foreground=False)
    elif loss_type == 'dice_onlyfg':
        segmentation_loss = losses.dice_loss(logits, labels, only_foreground=True)
    else:  # loss_type == 'crossentropy_and_dice' (guaranteed by the validation above)
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(logits, labels) + 0.2*losses.dice_loss(logits, labels)

    total_loss = tf.add(segmentation_loss, weights_norm)

    return total_loss, segmentation_loss, weights_norm
def predict(images, exp_config):
    '''
    Returns the prediction for an image given a network from the model zoo
    :param images: An input image tensor
    :param exp_config: The experiment configuration; must provide `model_handle` (the model function to
                       evaluate) and `nlabels` (passed on to the model function)
    :return: A prediction mask, and the corresponding softmax output
    '''

    # The network is always built with the training flag set to a constant False here.
    logits = exp_config.model_handle(images, training=tf.constant(False, dtype=tf.bool), nlabels=exp_config.nlabels)
    softmax = tf.nn.softmax(logits)

    # tf.argmax replaces the deprecated tf.arg_max; the mask is the most likely label along the last axis.
    mask = tf.argmax(softmax, axis=-1)

    return mask, softmax
def training_step(loss, optimizer_handle, learning_rate, **kwargs):
    '''
    Builds the op that performs one optimisation step; it is meant to be run once per training iteration.
    :param loss: The loss tensor that the optimiser should minimise
    :param optimizer_handle: A callable (e.g. one of the tf optimiser classes) that creates the optimiser
    :param learning_rate: Learning rate handed to the optimiser
    :param momentum: Optional keyword argument; if given it is also handed to the optimiser. Any other
                     keyword arguments are ignored.
    :return: The training operation
    '''

    optimizer_args = {'learning_rate': learning_rate}
    if 'momentum' in kwargs:
        optimizer_args['momentum'] = kwargs['momentum']
    optimizer = optimizer_handle(**optimizer_args)

    # Everything registered under UPDATE_OPS has to run before the optimisation step; this is what makes
    # the tf contrib version of batch norm perform its updates.
    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        return optimizer.minimize(loss)
def evaluation(logits, labels, images, nlabels, loss_type):
    '''
    A function for evaluating the performance of the network on a minibatch. This function returns the loss and the
    current foreground Dice score, and also writes example segmentations and images to tensorboard.
    :param logits: Output of network before softmax
    :param labels: Ground-truth label mask
    :param images: Input image mini batch
    :param nlabels: Number of labels in the dataset
    :param loss_type: Which loss should be evaluated
    :return: The loss without weight decay, the foreground dice of a minibatch
    '''

    # tf.argmax replaces the deprecated tf.arg_max; the predicted label is taken along the last axis.
    mask = tf.argmax(tf.nn.softmax(logits, dim=-1), axis=-1)
    mask_gt = labels

    tf.summary.image('example_gt', prepare_tensor_for_summary(mask_gt, mode='mask', nlabels=nlabels))
    tf.summary.image('example_pred', prepare_tensor_for_summary(mask, mode='mask', nlabels=nlabels))
    tf.summary.image('example_zimg', prepare_tensor_for_summary(images, mode='image'))

    # Only the loss without weight decay is reported; the total loss and the weight norm are not needed here.
    _, nowd_loss, _ = loss(logits, labels, nlabels=nlabels, loss_type=loss_type)

    cdice_structures = losses.per_structure_dice(logits, tf.one_hot(labels, depth=nlabels))
    # Drop the first column (presumably label 0 = background - TODO confirm) before averaging.
    cdice_foreground = cdice_structures[:, 1:]
    cdice = tf.reduce_mean(cdice_foreground)

    return nowd_loss, cdice
def prepare_tensor_for_summary(img, mode, idx=0, nlabels=None, slice_idx=10):
    '''
    Format a tensor containing images or segmentation masks such that it can be used with
    tf.summary.image(...) and displayed in tensorboard.
    :param img: Input image or segmentation mask (rank 3, 4 or 5)
    :param mode: Can be either 'image' or 'mask'. The two require slightly different slicing
    :param idx: Which index of a minibatch to display. By default it's always the first
    :param nlabels: Used for the proper rescaling of the label values. If None it scales by the max label.
    :param slice_idx: Index of the slice that is displayed for rank-4 masks (last axis) and for rank-5
                      inputs (second to last axis). Defaults to 10, the previously hard-coded value.
    :return: Tensor ready to be used with tf.summary.image(...)
    :raises ValueError: If mode is unknown or the rank of img is not 3, 4 or 5
    '''

    if mode not in ('mask', 'image'):
        raise ValueError('Unknown mode: %s. Must be image or mask' % mode)

    # Reduce the input to a single 2D slice of minibatch element idx.
    ndims = img.get_shape().ndims
    if ndims == 3:
        V = img[idx, ...]
    elif ndims == 4:
        # Rank-4 masks are sliced at slice_idx along the last axis; rank-4 images take entry 0 of the last axis.
        V = img[idx, ..., slice_idx] if mode == 'mask' else img[idx, ..., 0]
    elif ndims == 5:
        V = img[idx, ..., slice_idx, 0]
    else:
        raise ValueError('Dont know how to deal with input dimension %d' % ndims)

    if mode == 'image' or not nlabels:
        # Rescale to [0, 1] using the value range of the slice itself.
        # NOTE(review): a constant slice makes the divisor zero - confirm whether this needs guarding.
        V = V - tf.reduce_min(V)
        V = V / tf.reduce_max(V)
    else:
        # NOTE(review): nlabels == 1 makes the divisor zero - confirm whether this needs guarding.
        V = V / (nlabels - 1)  # The largest value in a label map is nlabels - 1.

    V = V * 255
    V = tf.cast(V, dtype=tf.uint8)

    # Reshape to (batch, dim 1 of img, dim 2 of img, 1 channel) as needed for an image summary.
    size_dim1 = tf.shape(img)[1]
    size_dim2 = tf.shape(img)[2]
    V = tf.reshape(V, tf.stack((-1, size_dim1, size_dim2, 1)))
    return V