diff --git a/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py b/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py index d90372818d..239de000f4 100644 --- a/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py +++ b/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py @@ -72,27 +72,29 @@ class FeaturePyramid(keras.layers.Layer): Example: ```python + images = keras.layers.Input( + image_shape, + name="images", + ) + extractor_levels= ["P2", "P3", "P4", "P5"] - inp = keras.layers.Input((384, 384, 3)) - backbone = keras.applications.EfficientNetB0( - input_tensor=inp, - include_top=False + backbone = keras_cv.models.ResNetV2Backbone.from_preset( + "resnet50_v2_imagenet", include_rescaling=True ) - layer_names = ['block2b_add', - 'block3b_add', - 'block5c_add', - 'top_activation' + + extractor_layer_names = [ + backbone.pyramid_level_inputs[i] for i in extractor_levels ] - backbone_outputs = {} - for i, layer_name in enumerate(layer_names): - backbone_outputs[i+2] = backbone.get_layer(layer_name).output + feature_extractor = get_feature_extractor( + backbone, + extractor_layer_names, + extractor_levels + ) + feature_pyramid = FeaturePyramid(min_level=2, max_level=5) - # output_dict is a dict with 2, 3, 4, 5 as keys - output_dict = keras_cv.layers.FeaturePyramid( - min_level=2, - max_level=5 - )(backbone_outputs) + backbone_outputs = feature_extractor(images) + feature_map = feature_pyramid(backbone_outputs) ``` """ @@ -146,15 +148,7 @@ def __init__( else: self._validate_user_layers(output_layers, "output_layers") self.output_layers = output_layers - # this layer is cutom to Faster R-CNN - self.final_conv = keras.layers.Conv2D( - self.num_channels, - kernel_size=3, - strides=1, - padding="same", - name=f"output_P{self.max_level+1}", - ) self.max_pool = keras.layers.MaxPool2D() # the same upsampling layer is used for all levels @@ -189,12 +183,13 @@ def call(self, features): def build_feature_pyramid(self, 
def build_feature_pyramid(self, input_features):
    """Run the top-down FPN pass and return the per-level feature maps.

    Args:
        input_features: mapping keyed by level name (`"P{level}"`). It must
            contain an entry for every level from `self.min_level` to
            `self.max_level`; any other keys are ignored.

    Returns:
        A dict keyed by level name with one entry for each level in
        `[min_level, max_level]`, plus one extra entry
        `"P{max_level + 1}"` obtained by applying `self.max_pool` to the
        final `"P{max_level}"` output.
    """
    # Topology for min_level=2, max_level=5 (every intermediate level
    # repeats the same lateral -> merge -> output pattern):
    #
    #                                             output_P6
    #                                                 ^
    #                                             max_pool
    #                                                 ^
    # input_P5 -> lateral_P5 ----+-----> output_layer_P5 -> output_P5
    #                            v
    #                       top_down_op
    #                            v
    # input_P4 -> lateral_P4 -> merge -> output_layer_P4 -> output_P4
    #                            v
    #                       top_down_op
    #                            v
    # input_P3 -> lateral_P3 -> merge -> output_layer_P3 -> output_P3
    #                            v
    #                       top_down_op
    #                            v
    # input_P2 -> lateral_P2 -> merge -> output_layer_P2 -> output_P2
    top_level = f"P{self.max_level}"
    # Highest level first: each level needs the merged result of the level
    # directly above it.
    level_indices = range(self.max_level, self.min_level - 1, -1)

    output_features = {}
    for index in level_indices:
        level = f"P{index}"
        output = self.lateral_layers[level](input_features[level])
        if index < self.max_level:
            # The top-most level has no upper stream output to merge with.
            upstream_output = self.top_down_op(
                output_features[f"P{index + 1}"]
            )
            output = self.merge_op([output, upstream_output])
        output_features[level] = output

    # Apply the output layers only after the whole top-down pass, so that
    # they do not leak into the downstream levels. Iterating the same
    # levels as above (rather than the input keys) keeps extra input
    # entries from raising a KeyError.
    for index in level_indices:
        level = f"P{index}"
        output_features[level] = self.output_layers[level](
            output_features[level]
        )

    # The extra coarsest level is pooled from the finished top-level output.
    output_features[f"P{self.max_level + 1}"] = self.max_pool(
        output_features[top_level]
    )
    return output_features
def get_config(self):
    """Return the configuration of this layer.

    Returns:
        A dict with every entry of the base class config followed by
        `min_level`, `max_level`, `num_channels`, `lateral_layers` and
        `output_layers`. The two layer entries come from
        `self.lateral_layers_passed` / `self.output_layers_passed`
        (presumably the values originally given to the constructor --
        confirm against `__init__`).
    """
    # Base config first, so that the keys defined here take precedence on
    # a name clash.
    return {
        **super().get_config(),
        "min_level": self.min_level,
        "max_level": self.max_level,
        "num_channels": self.num_channels,
        "lateral_layers": self.lateral_layers_passed,
        "output_layers": self.output_layers_passed,
    }