Skip to content

Commit

Permalink
fixed RPN
Browse files Browse the repository at this point in the history
  • Loading branch information
DragosBobolea committed Apr 28, 2020
1 parent 1f6a940 commit 6cea42b
Show file tree
Hide file tree
Showing 8 changed files with 332 additions and 162 deletions.
76 changes: 76 additions & 0 deletions backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import tensorflow as tf
keras = tf.keras
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, MaxPool2D


def _resnet_v1_50_block(input, base_depth, conv1stride=1):
    """Apply one bottleneck residual unit to a feature map.

    The main branch is three convolutions (1x1, 3x3, 1x1), each followed by
    batch normalisation and ReLU; the last one widens the feature map to
    ``base_depth * 4`` channels. The result is added to a shortcut branch,
    which is either the untouched input or a strided 1x1 projection of it.

    Args:
        input: Keras tensor to transform. The channel count is read from
            axis 3, so a channels-last 4-D tensor is assumed.
        base_depth: Number of filters of the first two convolutions; the
            block outputs ``base_depth * 4`` channels.
        conv1stride: Stride of the first 1x1 convolution (and of the
            shortcut projection). An int or a per-axis tuple/list.

    Returns:
        The sum of the main branch and the shortcut branch.
    """
    x = Conv2D(base_depth, kernel_size=1, strides=conv1stride, padding='same')(input)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = Conv2D(base_depth, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = Conv2D(base_depth * 4, kernel_size=1, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    # NOTE(review): the activation is applied before the addition (and also on
    # the projected shortcut below). The canonical ResNet v1 unit applies a
    # single ReLU after the addition; kept as-is to preserve behaviour -
    # confirm this ordering is intentional.
    x = ReLU()(x)

    # The identity shortcut is only valid when the main branch changed neither
    # the channel count nor the spatial resolution. A strided first
    # convolution shrinks height/width, so it needs a projection even when
    # the channel counts happen to match.
    if isinstance(conv1stride, (tuple, list)):
        is_strided = any(step != 1 for step in conv1stride)
    else:
        is_strided = conv1stride != 1
    channels_differ = input.shape[3] != x.shape[3]

    if channels_differ or is_strided:
        residual = Conv2D(x.shape[3], kernel_size=1, strides=conv1stride, padding='same')(input)
        residual = BatchNormalization()(residual)
        residual = ReLU()(residual)
    else:
        residual = input

    x = residual + x
    return x

def _resnet_v1_50(input):
    """Stack the ResNet-50-style stem and bottleneck stages on ``input``.

    The stem is a 7x7 stride-2 convolution with batch normalisation and
    ReLU, followed by a 2x2 stride-2 max pool. Four stages of bottleneck
    units (3, 4, 6 and 3 units) follow; every stage except the first starts
    with a stride-2 unit.

    Args:
        input: Keras tensor holding the image batch.

    Returns:
        The feature map produced by the last bottleneck unit.
    """
    # Stem ("block 1").
    stem = Conv2D(64, kernel_size=7, strides=2, padding='same')(input)
    stem = BatchNormalization()(stem)
    stem = ReLU()(stem)
    features = MaxPool2D(pool_size=2, strides=2, padding='same')(stem)

    # One row per stage ("block 2" .. "block 5"):
    # (base depth, stride of the first unit, number of units in the stage).
    stage_specs = (
        (64, 1, 3),
        (128, 2, 4),
        (256, 2, 6),
        (512, 2, 3),
    )
    for base_depth, first_stride, unit_count in stage_specs:
        for unit_index in range(unit_count):
            # Only the first unit of a stage may downsample.
            unit_stride = first_stride if unit_index == 0 else 1
            features = _resnet_v1_50_block(features, base_depth, conv1stride=unit_stride)

    return features

class ResNet50(Model):
    """Keras model wrapping the ``_resnet_v1_50`` feature extractor.

    The layer graph is built once, at construction time, and reused by every
    forward pass. Building it inside ``call`` would instantiate a brand new
    set of layers (and therefore new, untracked weights) on each invocation.

    Args:
        input_channels: Number of channels of the input images. Height and
            width are left unspecified so any image size is accepted.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, input_channels=3, **kwargs):
        super(ResNet50, self).__init__(**kwargs)
        trunk_input = Input(shape=(None, None, input_channels))
        # Functional sub-model owning all layers, so their variables are
        # created once and tracked through this attribute.
        self._trunk = Model(inputs=[trunk_input], outputs=[_resnet_v1_50(trunk_input)])

    def call(self, input, training=False):
        """Run the backbone on ``input`` and return its feature map.

        Args:
            input: Batch of images, channels last.
            training: Forwarded to the wrapped layers so batch normalisation
                can switch between batch statistics and moving averages.

        Returns:
            The feature map produced by the last bottleneck unit.
        """
        return self._trunk(input, training=training)
if __name__ == '__main__':
    from helpers import get_random_image

    # Smoke test: build the backbone as a functional model, relax the stride
    # of its first convolutions and push one random image through it.
    image_input = Input(shape=(None, None, 3))
    backbone = Model(inputs=[image_input], outputs=[_resnet_v1_50(image_input)])

    # Reset the stride of the first two Conv2D layers to 1.
    # NOTE(review): the source indentation was lost; this assumes the counter
    # was advanced only for Conv2D layers. Also confirm that changing
    # `strides` after the model has been constructed actually takes effect
    # in the TensorFlow version in use.
    conv_layers = [
        layer for layer in backbone.layers
        if isinstance(layer, keras.layers.Conv2D)
    ]
    for conv_layer in conv_layers[:2]:
        conv_layer.strides = (1, 1)

    # `boxes` is returned by the helper but not needed for this check.
    image, boxes = get_random_image(shape=(224, 224))
    batch = np.expand_dims(image, axis=0)

    print(backbone(batch))
2 changes: 2 additions & 0 deletions checkpoint
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
model_checkpoint_path: "model_weights.ckpt"
all_model_checkpoint_paths: "model_weights.ckpt"
4 changes: 2 additions & 2 deletions helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def get_random_image(shape):
y = i_h * h_slice + np.random.randint(1,h_slice)
x2 = 40 + x + np.random.randint(1, x - i_w * w_slice + 1)
y2 = 40 + y + np.random.randint(1, y - i_h * h_slice + 1)
image[y:y2, x:x2] = 1
boxes.append([x / float(shape[1] + 1), y/ float(shape[1] + 1), x2/ float(shape[1] + 1), y2/ float(shape[1] + 1)])
image[y:y2, x:x2] = np.random.ranf() * 0.5 + 0.5
boxes.append([x / float(shape[1]), y/ float(shape[0]), x2/ float(shape[1]), y2/ float(shape[0])])
boxes = np.clip(boxes, 0, 1)
image -= 0.5
return image, np.array(boxes)
Binary file added model_weights.ckpt.data-00000-of-00002
Binary file not shown.
Binary file added model_weights.ckpt.data-00001-of-00002
Binary file not shown.
Binary file added model_weights.ckpt.index
Binary file not shown.
80 changes: 52 additions & 28 deletions rpn_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,27 @@
import tensorflow as tf
keras = tf.keras
from helpers import intersection_over_union
# layers = []
class RegionProposalNetwork(keras.Model):
def __init__(self, backbone, scales, ratios):
def __init__(self, scales, ratios):
super(RegionProposalNetwork, self).__init__()
# hard-coded parameters (for now)
self.stride = 32
self.stride = 8
self.base_anchor_size = 64
self.positive_iou_threshold = 0.7
self.positive_iou_threshold = 0.5
self.negative_iou_threshold = 0.3
self.batch_size = 2
self.batch_size = 256
self.positives_ratio = 0.5
self.minibatch_positives_number = int(self.positives_ratio * self.batch_size)
self.minibatch_negatives_number = self.batch_size - self.minibatch_positives_number
self.max_number_of_predictions = 400
self.loss_classification_weight = 1
self.loss_regression_weight = 1
self.loss_regression_weight = 2
self.objectness_threshold = 0.9
self.pre_nms_top_k = 1000
self.nms_iou_threshold = 0.5
self.post_nms_top_k = 1000
# parameters
self.backbone = backbone
self.scales = scales
self.ratios = ratios
self.anchor_templates = self.__get_anchor_templates()
Expand All @@ -30,27 +34,39 @@ def __init__(self, backbone, scales, ratios):
self.box_classification = keras.layers.Conv2D(filters=2 * len(self.anchor_templates), kernel_size=1)
self.classification_softmax = keras.activations.softmax

@tf.function
def call(self, input, training=False):
    """Run the RPN head on a backbone feature map.

    Args:
        input: Feature map fed to ``self.conv1``.
        training: Accepted for Keras compatibility; not used here.

    Returns:
        A tensor of shape ``(batch, height, width, num_anchors, 6)``. The
        first two entries of the last axis are the softmax-normalised
        classification scores, the remaining four are the box regression
        outputs.
    """
    anchors_per_cell = len(self.anchor_templates)

    x = self.conv1(input)
    x = self.conv2(x)

    # Use the dynamic shape (tf.shape) rather than the static `.shape`:
    # under tf.function the batch and spatial dimensions may be None, which
    # tf.reshape cannot accept.
    output_regression = self.box_regression(x)
    regression_shape = tf.shape(output_regression)
    output_regression = tf.reshape(
        output_regression,
        (regression_shape[0], regression_shape[1], regression_shape[2], anchors_per_cell, 4),
    )

    output_classification = self.box_classification(x)
    classification_shape = tf.shape(output_classification)
    output_classification = tf.reshape(
        output_classification,
        (classification_shape[0], classification_shape[1], classification_shape[2], anchors_per_cell, 2),
    )
    # Normalise the two class scores of every anchor.
    output_classification = self.classification_softmax(output_classification, axis=4)

    output = tf.concat((output_classification, output_regression), axis=4)
    return output

def get_boxes(self, predictions):
rpn_output_classification = predictions[:,:,:,:,:2]
@tf.function
def get_boxes(self, predictions, image_shape):
anchors = self.generate_anchors(predictions, image_shape)

rpn_output_classification = predictions[:,:,:,:,1:2]
rpn_output_regression = predictions[:,:,:,:,2:]

positives_mask = tf.argmax(rpn_output_classification,axis=4) == 1
anchors = self.generate_anchors(predictions)
original_shape = tf.shape(rpn_output_classification)[:-1]
rpn_output_classification_flattened = tf.reshape(rpn_output_classification, [-1])
# _, positive_indices = tf.math.top_k(rpn_output_classification_flattened, k=20)
# positive_indices = tf.unravel_index(positive_indices, original_shape)
# positive_anchors = tf.gather_nd(anchors, tf.transpose(positive_indices))
# positive_regressions = tf.gather_nd(rpn_output_regression, tf.transpose(positive_indices))

positives_mask = tf.squeeze(rpn_output_classification >= self.objectness_threshold, axis=4)
positives_scores = tf.boolean_mask(rpn_output_classification, positives_mask)
positive_anchors = tf.boolean_mask(anchors, positives_mask)
positive_regressions = tf.boolean_mask(rpn_output_regression, positives_mask)


positive_anchors_coords = tf.unstack(positive_anchors, axis=1)
positive_anchors_left = positive_anchors_coords[0]
Expand All @@ -62,25 +78,25 @@ def get_boxes(self, predictions):
positive_anchors_w = (positive_anchors_right - positive_anchors_left)
positive_anchors_h = (positive_anchors_bottom - positive_anchors_top)


boxes_x = positive_anchors_x + positive_regressions[:,0] * positive_anchors_w
boxes_y = positive_anchors_y + positive_regressions[:,1] * positive_anchors_h
boxes_w = tf.math.exp(positive_regressions[:,2]) * positive_anchors_w
boxes_h = tf.math.exp(positive_regressions[:,3]) * positive_anchors_h



boxes = tf.stack([boxes_x - boxes_w/2, boxes_y - boxes_h/2, boxes_x + boxes_w/2, boxes_y + boxes_h/2],axis=1)

positives_scores = tf.reshape(positives_scores, [-1])
selected_indices = tf.image.non_max_suppression(boxes, positives_scores, max_output_size=self.post_nms_top_k, iou_threshold=self.nms_iou_threshold)
boxes = tf.gather(boxes, selected_indices)

return boxes

# @tf.function
def rpn_loss(self, ground_truths, rpn_output):
@tf.function
def rpn_loss(self, ground_truths, rpn_output, image_shape):
# identify positive anchors
# create a minibatch of anchors/ground truths
# apply L1 to minibatch
anchors = self.generate_anchors(rpn_output)
anchors = self.generate_anchors(rpn_output, image_shape)
positive_anchor_indices, positive_gt_indices, negative_anchor_indices = self.generate_minibatch(anchors, ground_truths)
ground_truth_targets = self.get_targets(anchors, ground_truths, positive_anchor_indices, positive_gt_indices, negative_anchor_indices)

Expand All @@ -105,6 +121,7 @@ def rpn_loss(self, ground_truths, rpn_output):
regression_loss = tf.losses.mean_absolute_error(positives_regression, ground_truth_targets)

return self.loss_regression_weight * tf.reduce_mean(regression_loss) + self.loss_classification_weight * tf.reduce_mean(classification_loss)
# return tf.reduce_mean(classification_loss)


'''
Expand All @@ -123,12 +140,12 @@ def rpn_loss(self, ground_truths, rpn_output):
def generate_minibatch(self, anchors, ground_truths):
positive_anchor_indices, positive_gt_indices, negative_anchor_indices = self.assign_anchors_to_ground_truths(anchors, ground_truths)
n_positives = tf.minimum(tf.shape(positive_anchor_indices)[2], self.minibatch_positives_number)
n_negatives = tf.minimum(tf.shape(negative_anchor_indices)[2], self.batch_size - n_positives)
# n_negatives = tf.minimum(tf.shape(negative_anchor_indices)[2], self.batch_size - n_positives)
n_negatives = tf.minimum(tf.shape(negative_anchor_indices)[2], int(float(n_positives) / self.positives_ratio))

indices = tf.range(tf.shape(positive_anchor_indices)[2])
indices = tf.random.shuffle(indices)
indices = tf.slice(indices, [0], [n_positives])

positive_anchor_indices = tf.gather(positive_anchor_indices, indices,axis=2)
positive_gt_indices = tf.gather(positive_gt_indices, indices,axis=1)

Expand Down Expand Up @@ -203,10 +220,12 @@ def __get_anchor_templates(self):
anchors: tensor of shape (1, height, width, num_anchors, 4)
'''
# @tf.function
def generate_anchors(self, feature_map):
def generate_anchors(self, feature_map, image_shape):
# TODO support minibatch by tiling anchors on first dimension
feature_map_shape = tf.shape(feature_map)
assert feature_map.shape[0] == 1
if tf.size(image_shape) == 4:
image_shape = image_shape[1:]
self.stride = tf.cast(image_shape[0] / feature_map_shape[1],tf.int32)
vertical_stride = tf.range(0,feature_map_shape[1])
vertical_stride = tf.tile(vertical_stride,[feature_map_shape[2]])
vertical_stride = tf.reshape(vertical_stride, (feature_map_shape[2], feature_map_shape[1]))
Expand All @@ -217,14 +236,18 @@ def generate_anchors(self, feature_map):
horizontal_stride = tf.reshape(horizontal_stride, (feature_map_shape[1], feature_map_shape[2]))

centers_xyxy = tf.stack([horizontal_stride, vertical_stride, horizontal_stride, vertical_stride], axis=2)

centers_xyxy = self.stride * centers_xyxy
centers_xyxy = tf.cast(centers_xyxy,tf.float32)
centers_xyxy = tf.cast(centers_xyxy, tf.float32) + 0.5
centers_xyxy = float(self.stride) * centers_xyxy

centers_xyxy = tf.tile(centers_xyxy,[1,1,self.anchor_templates.shape[0]])
centers_xyxy = tf.reshape(centers_xyxy, (feature_map_shape[1], feature_map_shape[2], self.anchor_templates.shape[0], 4))
anchors = centers_xyxy + self.anchor_templates
# TODO properly convert to normalized

normalize = tf.cast(tf.gather(image_shape, [1,0,1,0]), tf.float32)
anchors /= normalize
anchors = tf.expand_dims(anchors,axis=0)

return anchors

'''
Expand All @@ -242,6 +265,7 @@ def generate_anchors(self, feature_map):
# @tf.function
def assign_anchors_to_ground_truths(self, anchors, ground_truths):
anchors = tf.cast(anchors, tf.float32)
anchors_shape = tf.shape(anchors)
ground_truths = tf.cast(ground_truths, tf.float32)
flattened_ground_truths = tf.reshape(ground_truths, (-1,4))
flattened_anchors = tf.reshape(anchors, (-1,4))
Expand All @@ -258,13 +282,13 @@ def assign_anchors_to_ground_truths(self, anchors, ground_truths):
positive_anchor_indices_flattened = tf.reshape(positive_anchor_indices_flattened, [-1])
positive_gt_indices = tf.gather(ground_truth_per_anchor, positive_anchor_indices_flattened)
positive_gt_indices = tf.expand_dims(positive_gt_indices, axis=0)
positive_anchor_indices = tf.unravel_index(positive_anchor_indices_flattened, (anchors.shape[1], anchors.shape[2], anchors.shape[3]))
positive_anchor_indices = tf.unravel_index(positive_anchor_indices_flattened, (anchors_shape[1], anchors_shape[2], anchors_shape[3]))
positive_anchor_indices = tf.expand_dims(positive_anchor_indices, axis=0)

negative_anchors = tf.less_equal(max_iou_per_anchor, self.negative_iou_threshold)
negative_anchor_indices_flattened = tf.where(negative_anchors)
negative_anchor_indices_flattened = tf.reshape(negative_anchor_indices_flattened, [-1])
negative_anchor_indices = tf.unravel_index(negative_anchor_indices_flattened, (anchors.shape[1], anchors.shape[2], anchors.shape[3]))
negative_anchor_indices = tf.unravel_index(negative_anchor_indices_flattened, (anchors_shape[1], anchors_shape[2], anchors_shape[3]))
negative_anchor_indices = tf.expand_dims(negative_anchor_indices, axis=0)
return positive_anchor_indices, positive_gt_indices, negative_anchor_indices

Loading

0 comments on commit 6cea42b

Please sign in to comment.