From 843401692fc7758595c17ae447c0aea094277b18 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Fri, 6 Sep 2024 12:10:48 -0700 Subject: [PATCH] Move files to maskrcnn folder and addressed all the required changes --- .../faster_rcnn/faster_rcnn.py | 2 +- .../mask_rcnn/faster_rcnn_backbone.py | 812 ++++++++++++++++++ .../object_detection/mask_rcnn/mask_head.py | 17 +- .../object_detection/mask_rcnn/mask_rcnn.py | 89 +- .../mask_rcnn/non_max_suppression.py | 566 ++++++++++++ .../object_detection/mask_rcnn/roi_sampler.py | 318 +++++++ 6 files changed, 1757 insertions(+), 47 deletions(-) create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/faster_rcnn_backbone.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/roi_sampler.py diff --git a/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py index f3ede6b75a..a73f40349e 100644 --- a/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py @@ -804,4 +804,4 @@ def unpack_input(data): if type(data) is dict: return data["images"], data["bounding_boxes"] else: - return data \ No newline at end of file + return data diff --git a/keras_cv/src/models/object_detection/mask_rcnn/faster_rcnn_backbone.py b/keras_cv/src/models/object_detection/mask_rcnn/faster_rcnn_backbone.py new file mode 100644 index 0000000000..07b1d0a220 --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/faster_rcnn_backbone.py @@ -0,0 +1,812 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
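+#
+# This module provides the Faster R-CNN first stage that MaskRCNN (see
+# mask_rcnn.py in this package) builds on: it mirrors the standalone
+# Faster R-CNN implementation, but wires in the mask-aware
+# NonMaxSuppression and ROISampler that live alongside it in the
+# mask_rcnn package.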
+ +import tree + +from keras_cv.src import losses +#from keras_cv.src.api_export import keras_cv_export +from keras_cv.src.backend import keras +from keras_cv.src.backend import ops +from keras_cv.src.bounding_box import convert_format +from keras_cv.src.bounding_box.converters import _decode_deltas_to_boxes +from keras_cv.src.bounding_box.utils import _clip_boxes +from keras_cv.src.layers.object_detection.anchor_generator import ( + AnchorGenerator, +) +from keras_cv.src.layers.object_detection.box_matcher import BoxMatcher +from keras_cv.src.layers.object_detection.roi_align import ROIAligner +from keras_cv.src.layers.object_detection.roi_generator import ROIGenerator +from keras_cv.src.layers.object_detection.rpn_label_encoder import ( + RpnLabelEncoder, +) +from keras_cv.src.models.object_detection.faster_rcnn import FeaturePyramid +from keras_cv.src.models.object_detection.faster_rcnn import RCNNHead +from keras_cv.src.models.object_detection.faster_rcnn import RPNHead +from keras_cv.src.models.object_detection.mask_rcnn.non_max_suppression import ( + NonMaxSuppression, +) +from keras_cv.src.models.object_detection.mask_rcnn.roi_sampler import ( + ROISampler, +) +from keras_cv.src.models.task import Task +from keras_cv.src.utils.train import get_feature_extractor + +BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] + + +# @keras_cv_export( +# [ +# "keras_cv.models.FasterRCNN", +# "keras_cv.models.object_detection.FasterRCNN", +# ] +# ) +class FasterRCNN(Task): + """A Keras model implementing the Faster R-CNN architecture. + + This model is compatible with Keras 3 only. Implements the Faster R-CNN architecture + for object detection. The constructor requires `num_classes`, `bounding_box_format`, + and a backbone. Optionally, a custom label encoder, and prediction decoder + may be provided. + + Example: + ```python + images = np.ones((1, 512, 512, 3)) + labels = { + "boxes": tf.cast([ + [ + [0, 0, 100, 100], + [100, 100, 200, 200], + [300, 300, 100, 100], + ] + ], dtype=tf.float32), + "classes": tf.cast([[1, 1, 1]], dtype=tf.float32), + } + model = FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + + # Evaluate model without box decoding and NMS + model(images) + + # Prediction with box decoding and NMS + model.predict(images) + + # Train model + model.compile( + optimizer=keras.optimizers.SGD(), + box_loss="Huber", + classification_loss="CategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + model.fit(images, labels, batch_size=1) + ``` + + Args: + backbone: `keras.Model`. If the default `feature_pyramid` is used, + must implement the `pyramid_level_inputs` property with keys "P3", "P4", + and "P5" and layer names as values. A somewhat sensible backbone + to use in many cases is the: + `keras_cv.models.ResNetBackbone.from_preset("resnet50_imagenet")` + num_classes: the number of classes in your dataset excluding the + background class. Classes should be represented by integers in the + range [1, num_classes]. + bounding_box_format: The format of bounding boxes of input dataset. + Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box formats. + anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. If + provided, the anchor generator will be passed to both the + `label_encoder` and the `prediction_decoder`. 
Only to be used when + both `label_encoder` and `prediction_decoder` are both `None`. + Defaults to an anchor generator with the parameterization: + `strides=[2**i for i in range(3, 8)]`, + `scales=[2**x for x in [0, 1 / 3, 2 / 3]]`, + `sizes=[32.0, 64.0, 128.0, 256.0, 512.0]`, + and `aspect_ratios=[0.5, 1.0, 2.0]`. + anchor_scales: (Optional) list of anchor scales for + default anchor generator. + anchor_aspect_ratios: (Optional) list of anchor aspect ratios for + default anchor generator. + feature_pyramid: (Optional) A `keras.layers.Layer` that produces + a list of 4D feature maps (batch dimension included) + when called on the pyramid-level outputs of the `backbone`. + If not provided, the reference implementation from the paper will be used. + fpn_min_level: (Optional) the minimum level of the feature pyramid. + fpn_max_level: (Optional) the maximum level of the feature pyramid. + rpn_head: (Optional) A `keras.Layer` that performs regression and + classification(background or foreground) of the bounding boxes. + If not provided, a simple ConvNet with 3 layers will be used. + rpn_label_encoder_posistive_threshold: (Optional) the float threshold to set an + anchor to positive match to gt box. Values above it are positive matches. + rpn_label_encoder_negative_threshold: (Optional) the float threshold to set an + anchor to negative match to gt box. Values below it are negative matches. + rpn_label_encoder_samples_per_image: (Optional) for each image, the number of + positive and negative samples to generate. + rpn_label_encoder_positive_fraction: (Optional) the fraction of positive samples to the total samples. + rcnn_head: (Optional) A `keras.Layer` that performs regression and + classification(final prediction) of the bounding boxes. + If not provided, a simple network with 2 dense layers with + box head and regression head will be used. + label_encoder: (Optional) a keras.Layer that accepts an image Tensor, a + bounding box Tensor and a bounding box class Tensor to its `call()` + method, and returns RetinaNet training targets. By default, a + KerasCV standard `RpnLabelEncoder` is created and used. + Results of this object's `call()` method are passed to the `loss` + object for `rpn_box_loss` and `rpn_classification_loss` the `y_true` + argument. + prediction_decoder: (Optional) A `keras.layers.Layer` that is + responsible for transforming RetinaNet predictions into usable + bounding box Tensors. If not provided, a default is provided. The + default `prediction_decoder` layer is a + `keras_cv.layers.MultiClassNonMaxSuppression` layer, which uses + a Non-Max Suppression for box pruning. + num_max_detections: the maximum detections to consider after nms is applied. A + large number may trigger significant memory overhead, defaults to 100. 
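+
+    As an illustrative sketch (not a recommended configuration), the anchor
+    generator can also be built explicitly and passed in. The arguments below
+    simply restate the constructor defaults (`fpn_min_level=2`,
+    `fpn_max_level=5` — hence `max_level=6`, since the constructor passes
+    `fpn_max_level + 1` — `anchor_scales=[1]`, and
+    `anchor_aspect_ratios=[0.5, 1.0, 2.0]`):
+
+    ```python
+    anchor_generator = FasterRCNN.default_anchor_generator(
+        min_level=2,
+        max_level=6,
+        scales=[1],
+        aspect_ratios=[0.5, 1.0, 2.0],
+        bounding_box_format="yxyx",
+    )
+    model = FasterRCNN(
+        num_classes=80,
+        bounding_box_format="xyxy",
+        backbone=keras_cv.models.ResNet18V2Backbone(
+            input_shape=(512, 512, 3)
+        ),
+        anchor_generator=anchor_generator,
+    )
+    # Raw (undecoded) predictions, as in the example above
+    model(np.ones((1, 512, 512, 3)))
+    ```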
+ """ # noqa: E501 + + def __init__( + self, + backbone, + num_classes, + bounding_box_format, + anchor_generator=None, + anchor_scales=[1], + anchor_aspect_ratios=[0.5, 1.0, 2.0], + feature_pyramid=None, + fpn_min_level=2, + fpn_max_level=5, + rpn_head=None, + rpn_filters=256, + rpn_kernel_size=3, + rpn_label_encoder_posistive_threshold=0.7, + rpn_label_encoder_negative_threshold=0.3, + rpn_label_encoder_samples_per_image=256, + rpn_label_encoder_positive_fraction=0.5, + rcnn_head=None, + num_sampled_rois=512, + label_encoder=None, + prediction_decoder=None, + num_max_decoder_detections=100, + *args, + **kwargs, + ): + # Backbone + extractor_levels = [ + f"P{level}" for level in range(fpn_min_level, fpn_max_level + 1) + ] + extractor_layer_names = [ + backbone.pyramid_level_inputs[i] for i in extractor_levels + ] + feature_extractor = get_feature_extractor( + backbone, extractor_layer_names, extractor_levels + ) + + # Feature Pyramid + feature_pyramid = feature_pyramid or FeaturePyramid( + min_level=fpn_min_level, max_level=fpn_max_level + ) + + # Anchors + anchor_generator = ( + anchor_generator + or FasterRCNN.default_anchor_generator( + fpn_min_level, + fpn_max_level + 1, + anchor_scales, + anchor_aspect_ratios, + "yxyx", + ) + ) + + # RPN Head + num_anchors_per_location = len(anchor_scales) * len( + anchor_aspect_ratios + ) + rpn_head = rpn_head or RPNHead( + num_anchors_per_location=num_anchors_per_location, + num_filters=rpn_filters, + kernel_size=rpn_kernel_size, + ) + + # RoI Generator + roi_generator = ROIGenerator( + bounding_box_format="yxyx", + nms_score_threshold_train=float("-inf"), + nms_score_threshold_test=float("-inf"), + nms_from_logits=True, + name="roi_generator", + ) + + # RoI Align + roi_aligner = ROIAligner(bounding_box_format="yxyx", name="roi_align") + + # R-CNN Head + rcnn_head = rcnn_head or RCNNHead(num_classes, name="rcnn_head") + + # Begin construction of forward pass + image_shape = feature_extractor.input_shape[1:] + if None in image_shape: + raise ValueError( + "Found `None` in image_shape, to build anchors `image_shape`" + "is required without any `None`. Make sure to pass " + "`image_shape` to the backbone preset while passing to" + "the Faster R-CNN detector." 
+ ) + + images = keras.layers.Input( + image_shape, + name="images", + ) + + # Forward through backbone + backbone_outputs = feature_extractor(images) + + # Forward through FPN decoder + feature_map = feature_pyramid(backbone_outputs) + + # [P2, P3, P4, P5, P6] -> ([BS, num_anchors, 4], [BS, num_anchors, 1]) + # Pass through RPN Head + rpn_boxes, rpn_scores = rpn_head(feature_map) + + # Reshape and Concatenate all the output boxes of all levels + for lvl in rpn_boxes: + rpn_boxes[lvl] = keras.layers.Reshape(target_shape=(-1, 4))( + rpn_boxes[lvl] + ) + + for lvl in rpn_scores: + rpn_scores[lvl] = keras.layers.Reshape(target_shape=(-1, 1))( + rpn_scores[lvl] + ) + rpn_cls_pred = keras.layers.Concatenate( + axis=1, name="rpn_classification" + )(tree.flatten(rpn_scores)) + rpn_box_pred = keras.layers.Concatenate(axis=1, name="rpn_box")( + tree.flatten(rpn_boxes) + ) + + anchors = anchor_generator(image_shape=image_shape) + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + + rois, roi_scores = roi_generator(decoded_rpn_boxes, rpn_scores) + rois = _clip_boxes(rois, "yxyx", image_shape) + + feature_map = roi_aligner(features=feature_map, boxes=rois) + + # Reshape the feature map [BS, H*W*K] + feature_map = keras.layers.Reshape( + target_shape=( + rois.shape[1], + (roi_aligner.target_size**2) * rpn_head.num_filters, + ) + )(feature_map) + + # Pass final feature map to RCNN Head for predictions + box_pred, cls_pred = rcnn_head(feature_map=feature_map) + + box_pred = keras.layers.Concatenate(axis=1, name="box")([box_pred]) + cls_pred = keras.layers.Concatenate(axis=1, name="classification")( + [cls_pred] + ) + + inputs = {"images": images} + outputs = { + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + } + + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + self.bounding_box_format = bounding_box_format + self.anchor_generator = anchor_generator + self.num_classes = num_classes + self.feature_extractor = feature_extractor + self.backbone = backbone + self.feature_pyramid = feature_pyramid + self.rpn_head = rpn_head + self.label_encoder = label_encoder or RpnLabelEncoder( + anchor_format="yxyx", + ground_truth_box_format=bounding_box_format, + positive_threshold=rpn_label_encoder_posistive_threshold, + negative_threshold=rpn_label_encoder_negative_threshold, + samples_per_image=rpn_label_encoder_samples_per_image, + positive_fraction=rpn_label_encoder_positive_fraction, + box_variance=BOX_VARIANCE, + ) + self.roi_generator = roi_generator + self.box_matcher = BoxMatcher( + thresholds=[0.0, 0.5], match_values=[-2, -1, 1] + ) + self.roi_sampler = ROISampler( + roi_bounding_box_format="yxyx", + gt_bounding_box_format=bounding_box_format, + roi_matcher=self.box_matcher, + num_sampled_rois=num_sampled_rois, + ) + + self.roi_aligner = roi_aligner + self.rcnn_head = rcnn_head + self._prediction_decoder = prediction_decoder or NonMaxSuppression( + bounding_box_format=bounding_box_format, + from_logits=False, + max_detections=num_max_decoder_detections, + ) + self.build(backbone.input_shape) + + def compile( + self, + rpn_box_loss=None, + rpn_classification_loss=None, + box_loss=None, + classification_loss=None, + weight_decay=0.0001, + loss=None, + metrics=None, + **kwargs, + ): + if loss is not None: + raise ValueError( + "`FasterRCNN` does not accept a `loss` to `compile()`. 
" + "Instead, please pass `box_loss` and `classification_loss`. " + "`loss` will be ignored during training." + ) + if ( + rpn_box_loss is None + or rpn_classification_loss is None + or box_loss is None + or classification_loss is None + ): + raise ValueError( + "`FasterRCNN` expects all of `rpn_box_loss`, " + "`rpn_classification_loss`," + "`box_loss`, and " + "`classification_loss` to be not `None`." + ) + + rpn_box_loss = _parse_box_loss(rpn_box_loss) + rpn_classification_loss = _parse_rpn_classification_loss( + rpn_classification_loss + ) + + if hasattr(rpn_classification_loss, "from_logits"): + if not rpn_classification_loss.from_logits: + raise ValueError( + "FasterRCNN.compile() expects `from_logits` to be True for " + "`rpn_classification_loss`. Got " + "`rpn_classification_loss.from_logits=" + f"{rpn_classification_loss.from_logits}`" + ) + box_loss = _parse_box_loss(box_loss) + classification_loss = _parse_classification_loss(classification_loss) + + if hasattr(classification_loss, "from_logits"): + if not classification_loss.from_logits: + raise ValueError( + "FasterRCNN.compile() expects `from_logits` to be True for " + "`classification_loss`. Got " + "`classification_loss.from_logits=" + f"{classification_loss.from_logits}`" + ) + if hasattr(box_loss, "bounding_box_format"): + if box_loss.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Wrong `bounding_box_format` passed to `box_loss` in " + "`FasterRCNN.compile()`. Got " + "`box_loss.bounding_box_format=" + f"{box_loss.bounding_box_format}`, want " + "`box_loss.bounding_box_format=" + f"{self.bounding_box_format}`" + ) + + self.rpn_box_loss = rpn_box_loss + self.rpn_cls_loss = rpn_classification_loss + self.box_loss = box_loss + self.cls_loss = classification_loss + self.weight_decay = weight_decay + losses = { + "rpn_box": self.rpn_box_loss, + "rpn_classification": self.rpn_cls_loss, + "box": self.box_loss, + "classification": self.cls_loss, + } + self._has_user_metrics = metrics is not None and len(metrics) != 0 + self._user_metrics = metrics + super().compile(loss=losses, **kwargs) + + def compute_loss( + self, x, y, y_pred, sample_weight, training=True, **kwargs + ): + + # 1. Unpack the inputs + images = x + gt_boxes = y["boxes"] + if ops.ndim(y["classes"]) != 2: + raise ValueError( + "Expected 'classes' to be a Tensor of rank 2. " + f"Got y['classes'].shape={ops.shape(y['classes'])}." 
+ ) + + gt_classes = y["classes"] + gt_classes = ops.expand_dims(gt_classes, axis=-1) + + # Generate Anchors and Generate RPN Targets + local_batch = ops.shape(images)[0] + image_shape = ops.shape(images)[1:] + anchors = self.anchor_generator(image_shape=image_shape) + + # Label with the anchors -- exclusive to compute_loss + ( + rpn_box_targets, + rpn_box_weights, + rpn_cls_targets, + rpn_cls_weights, + ) = self.label_encoder( + anchors_dict=ops.concatenate( + tree.flatten(anchors), + axis=0, + ), + gt_boxes=gt_boxes, + gt_classes=gt_classes, + ) + + # Computing the weights + rpn_box_weights /= ( + self.label_encoder.samples_per_image * local_batch * 0.25 + ) + rpn_cls_weights /= self.label_encoder.samples_per_image * local_batch + + # Call Backbone, FPN and RPN Head + backbone_outputs = self.feature_extractor(images) + feature_map = self.feature_pyramid(backbone_outputs) + rpn_boxes, rpn_scores = self.rpn_head(feature_map) + + for lvl in rpn_boxes: + rpn_boxes[lvl] = keras.layers.Reshape(target_shape=(-1, 4))( + rpn_boxes[lvl] + ) + + for lvl in rpn_scores: + rpn_scores[lvl] = keras.layers.Reshape(target_shape=(-1, 1))( + rpn_scores[lvl] + ) + + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_cls_pred = keras.layers.Concatenate( + axis=1, name="rpn_classification" + )(tree.flatten(rpn_scores)) + rpn_box_pred = keras.layers.Concatenate(axis=1, name="rpn_box")( + tree.flatten(rpn_boxes) + ) + + # Generate RoI's and RoI Sampling + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + + rois, _ = self.roi_generator( + decoded_rpn_boxes, rpn_scores, training=training + ) + rois = _clip_boxes(rois, "yxyx", image_shape) + + # Stop gradient from flowing into the ROI + # -- exclusive to compute_loss + rois = ops.stop_gradient(rois) + + # Sample the ROIS -- exclusive to compute_loss + ( + rois, + box_targets, + box_weights, + cls_targets, + cls_weights, + ) = self.roi_sampler(rois, gt_boxes, gt_classes) + + cls_targets = ops.squeeze(cls_targets, axis=-1) + cls_weights = ops.squeeze(cls_weights, axis=-1) + + # Box and class weights -- exclusive to compute loss + box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 + cls_weights /= self.roi_sampler.num_sampled_rois * local_batch + cls_targets = ops.one_hot(cls_targets, num_classes=self.num_classes + 1) + + # Call RoI Aligner and RCNN Head + feature_map = self.roi_aligner(features=feature_map, boxes=rois) + + # [BS, H*W*K] + feature_map = ops.reshape( + feature_map, + newshape=ops.shape(rois)[:2] + (-1,), + ) + + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + box_pred, cls_pred = self.rcnn_head(feature_map=feature_map) + + y_true = { + "rpn_box": rpn_box_targets, + "rpn_classification": rpn_cls_targets, + "box": box_targets, + "classification": cls_targets, + } + y_pred = { + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + } + weights = { + "rpn_box": rpn_box_weights, + "rpn_classification": rpn_cls_weights, + "box": box_weights, + "classification": cls_weights, + } + + return super().compute_loss( + x=x, y=y_true, y_pred=y_pred, sample_weight=weights, **kwargs + ) + + def train_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().train_step(*args, (x, y)) + + def test_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().test_step(*args, (x, y)) + + def 
predict_step(self, *args): + outputs = super().predict_step(*args) + if type(outputs) is tuple: + return self.decode_predictions(outputs[0], args[-1]), outputs[1] + else: + return self.decode_predictions(outputs, args[-1]) + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + if prediction_decoder.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Expected `prediction_decoder` and FasterRCNN to " + "use the same `bounding_box_format`, but got " + "`prediction_decoder.bounding_box_format=" + f"{prediction_decoder.bounding_box_format}`, and " + "`self.bounding_box_format=" + f"{self.bounding_box_format}`." + ) + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + self.make_train_function(force=True) + self.make_test_function(force=True) + + def decode_predictions(self, predictions, images): + image_shape = ops.shape(images)[1:] + anchors = self.anchor_generator(image_shape=image_shape) + rpn_boxes, rpn_scores = ( + predictions["rpn_box"], + predictions["rpn_classification"], + ) + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=ops.concatenate( + tree.flatten(anchors), + axis=0, + ), + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + + rois, _ = self.roi_generator( + decoded_rpn_boxes, rpn_scores, training=False + ) + rois = _clip_boxes(rois, "yxyx", image_shape) + box_pred, cls_pred = predictions["box"], predictions["classification"] + + # box_pred is on "center_yxhw" format, convert to target format. + box_pred = _decode_deltas_to_boxes( + anchors=rois, + boxes_delta=box_pred, + anchor_format=self.roi_aligner.bounding_box_format, + box_format=self.bounding_box_format, + variance=BOX_VARIANCE, + image_shape=image_shape, + ) + + box_pred = convert_format( + box_pred, + source=self.bounding_box_format, + target=self.prediction_decoder.bounding_box_format, + image_shape=image_shape, + ) + cls_pred = ops.softmax(cls_pred) + cls_pred = ops.slice( + cls_pred, + start_indices=[0, 0, 1], + shape=[cls_pred.shape[0], cls_pred.shape[1], cls_pred.shape[2] - 1], + ) + + y_pred = self.prediction_decoder( + box_pred, + cls_pred, + mask_prediction=predictions.get("segmask"), + image_shape=image_shape, + ) + + y_pred["classes"] = ops.where( + y_pred["classes"] == -1, -1, y_pred["classes"] + 1 + ) + + y_pred["boxes"] = convert_format( + y_pred["boxes"], + source=self.prediction_decoder.bounding_box_format, + target=self.bounding_box_format, + image_shape=image_shape, + ) + return y_pred + + def compute_metrics(self, x, y, y_pred, sample_weight): + metrics = {} + metrics.update(super().compute_metrics(x, {}, {}, sample_weight={})) + + if not self._has_user_metrics: + return metrics + + y_pred = self.decode_predictions(y_pred, x) + + for metric in self._user_metrics: + metric.update_state(y, y_pred, sample_weight=sample_weight) + + for metric in self._user_metrics: + result = metric.result() + if isinstance(result, dict): + metrics.update(result) + else: + metrics[metric.name] = result + return metrics + + @staticmethod + def default_anchor_generator( + min_level, max_level, scales, aspect_ratios, bounding_box_format + ): + strides = {f"P{i}": 2**i for i in range(min_level, max_level + 1)} + sizes = {f"P{i}": 2 ** (3 + i) for i in range(min_level, max_level + 1)} + return AnchorGenerator( + bounding_box_format=bounding_box_format, + sizes=sizes, + aspect_ratios=aspect_ratios, + 
scales=scales, + strides=strides, + clip_boxes=True, + name="anchor_generator", + ) + + def get_config(self): + return { + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "backbone": keras.saving.serialize_keras_object(self.backbone), + "label_encoder": keras.saving.serialize_keras_object( + self.label_encoder + ), + "rpn_head": keras.saving.serialize_keras_object(self.rpn_head), + "prediction_decoder": self._prediction_decoder, + "rcnn_head": self.rcnn_head, + } + + @classmethod + def from_config(cls, config): + if "rpn_head" in config and isinstance(config["rpn_head"], dict): + config["rpn_head"] = keras.layers.deserialize(config["rpn_head"]) + if "label_encoder" in config and isinstance( + config["label_encoder"], dict + ): + config["label_encoder"] = keras.layers.deserialize( + config["label_encoder"] + ) + if "prediction_decoder" in config and isinstance( + config["prediction_decoder"], dict + ): + config["prediction_decoder"] = keras.layers.deserialize( + config["prediction_decoder"] + ) + if "rcnn_head" in config and isinstance(config["rcnn_head"], dict): + config["rcnn_head"] = keras.layers.deserialize(config["rcnn_head"]) + + return super().from_config(config) + + +def _parse_box_loss(loss): + # support arbitrary callables + if not isinstance(loss, str): + return loss + + # case insensitive comparison + if loss.lower() == "smoothl1": + return losses.SmoothL1Loss(l1_cutoff=1.0, reduction="sum") + if loss.lower() == "huber": + return keras.losses.Huber(reduction="sum") + + raise ValueError( + "Expected `box_loss` to be either a Keras Loss, " + f"callable, or the string 'SmoothL1'. Got loss={loss}." + ) + + +def _parse_rpn_classification_loss(loss): + # support arbitrary callables + if not isinstance(loss, str): + return loss + + if loss.lower() == "binarycrossentropy": + return keras.losses.BinaryCrossentropy( + reduction="sum", from_logits=True + ) + + raise ValueError( + f"Expected `rpn_classification_loss` to be either BinaryCrossentropy" + f" loss callable, or the string 'BinaryCrossentropy'. Got loss={loss}." + ) + + +def _parse_classification_loss(loss): + # support arbitrary callables + if not isinstance(loss, str): + return loss + + # case insensitive comparison + if loss.lower() == "focal": + return losses.FocalLoss(reduction="sum", from_logits=True) + if loss.lower() == "categoricalcrossentropy": + return keras.losses.CategoricalCrossentropy( + reduction="sum", from_logits=True + ) + + raise ValueError( + f"Expected `classification_loss` to be either a Keras Loss, " + f"callable, or the string 'Focal', CategoricalCrossentropy'. " + f"Got loss={loss}." + ) + + +def unpack_input(data): + if type(data) is dict: + return data["images"], data["bounding_boxes"] + else: + return data \ No newline at end of file diff --git a/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py b/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py index 4e6a245965..c266af4657 100644 --- a/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py +++ b/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py @@ -17,20 +17,21 @@ @keras_cv_export( - "keras_cv.models.faster_rcnn.MaskHead", - package="keras_cv.models.faster_rcnn", + "keras_cv.models.mask_rcnn.MaskHead", + package="keras_cv.models.mask_rcnn", ) class MaskHead(keras.layers.Layer): """A Keras layer implementing the R-CNN Mask Head. - The architecture is adopted from Matterport's Mask R-CNN implementation. 
+ The architecture is adopted from Matterport's Mask R-CNN implementation + https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/model.py. Args: num_classes: The number of object classes that are being detected. conv_dims: (Optional) a list of integers specifying the number of - filters for each convolutional layer. Defaults to []. - deconv_dim: (Optional) the numver of filters to use in the upsampling - convolutional layer. + filters for each convolutional layer. Defaults to [256, 256]. + deconv_dim: (Optional) the number of filters to use in the upsampling + convolutional layer. Defaults to 256. """ def __init__( @@ -70,7 +71,7 @@ def __init__( ) # we do not use a final sigmoid activation, since we use # from_logits=True during training - self.segmask_output = keras.layers.TimeDistributed( + self.segmentation_mask_output = keras.layers.TimeDistributed( keras.layers.Conv2D( num_classes + 1, kernel_size=1, @@ -89,7 +90,7 @@ def call(self, feature_map, training=False): def build(self, input_shape): intermediate_shape = input_shape - for idx, conv_dim in self.conv_dims: + for idx, conv_dim in enumerate(self.conv_dims): self.layers[idx * 3].build(intermediate_shape) intermediate_shape = tuple(intermediate_shape[:-1]) + (conv_dim,) self.layers[idx * 3 + 1].build(intermediate_shape) diff --git a/keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn.py b/keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn.py index f47a95f17f..0f616f6434 100644 --- a/keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn.py +++ b/keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn.py @@ -13,7 +13,6 @@ # limitations under the License. -import tensorflow as tf import tree from keras_cv.src.api_export import keras_cv_export @@ -22,23 +21,25 @@ from keras_cv.src.bounding_box import convert_format from keras_cv.src.bounding_box.converters import _decode_deltas_to_boxes from keras_cv.src.bounding_box.utils import _clip_boxes -from keras_cv.src.layers.object_detection.roi_sampler import ROISampler -from keras_cv.src.models.object_detection.faster_rcnn import MaskHead -from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import ( +from keras_cv.src.models.object_detection.mask_rcnn.faster_rcnn_backbone import ( BOX_VARIANCE, ) -from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import ( +from keras_cv.src.models.object_detection.mask_rcnn.faster_rcnn_backbone import ( _parse_box_loss, ) -from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import ( +from keras_cv.src.models.object_detection.mask_rcnn.faster_rcnn_backbone import ( _parse_classification_loss, ) -from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import ( +from keras_cv.src.models.object_detection.mask_rcnn.faster_rcnn_backbone import ( _parse_rpn_classification_loss, ) -from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import ( +from keras_cv.src.models.object_detection.mask_rcnn.faster_rcnn_backbone import ( unpack_input, ) +from keras_cv.src.models.object_detection.mask_rcnn.mask_head import MaskHead +from keras_cv.src.models.object_detection.mask_rcnn.roi_sampler import ( + ROISampler, +) from keras_cv.src.models.task import Task @@ -58,12 +59,11 @@ class MaskRCNN(Task): This model is compatible with Keras 3 only. Args: - backbone: `keras.Model`. A FasterRCNN model that is used for - object detection. + backbone: `keras.Model`. A FasterRCNN backbone model. mask_head: (Optional) A `keras.Layer` that performs regression of - the segmentation masks. 
- If not provided, a network with 2 convolutional layers, an - upsampling layer and a class-specific layer will be used. + the segmentation masks. If not provided, a network with + 2 convolutional layers, an upsampling layer and a class-specific + layer will be used. """ # noqa: E501 def __init__(self, backbone, mask_head=None, **kwargs): @@ -94,7 +94,7 @@ def __init__(self, backbone, mask_head=None, **kwargs): roi_bounding_box_format="yxyx", gt_bounding_box_format=backbone.bounding_box_format, roi_matcher=backbone.box_matcher, - num_sampled_rois=64, + num_sampled_rois=backbone.roi_sampler.num_sampled_rois, mask_shape=(14, 14), ) @@ -124,7 +124,7 @@ def compile( if hasattr(rpn_classification_loss, "from_logits"): if not rpn_classification_loss.from_logits: raise ValueError( - "FasterRCNN.compile() expects `from_logits` to be True for " + "MaskRCNN.compile() expects `from_logits` to be True for " "`rpn_classification_loss`. Got " "`rpn_classification_loss.from_logits=" f"{rpn_classification_loss.from_logits}`" @@ -136,7 +136,7 @@ def compile( if hasattr(classification_loss, "from_logits"): if not classification_loss.from_logits: raise ValueError( - "FasterRCNN.compile() expects `from_logits` to be True for " + "MaskRCNN.compile() expects `from_logits` to be True for " "`classification_loss`. Got " "`classification_loss.from_logits=" f"{classification_loss.from_logits}`" @@ -148,7 +148,7 @@ def compile( ): raise ValueError( "Wrong `bounding_box_format` passed to `box_loss` in " - "`FasterRCNN.compile()`. Got " + "`MaskRCNN.compile()`. Got " "`box_loss.bounding_box_format=" f"{box_loss.bounding_box_format}`, want " "`box_loss.bounding_box_format=" @@ -368,11 +368,12 @@ def test_step(self, *args): def decode_predictions(self, predictions, images): y_pred = self.backbone.decode_predictions(predictions, images) image_shape = ops.shape(images)[1:] + segmask_pred = ops.sigmoid(y_pred["segmask"]) y_pred["segmask"] = self.decode_segmentation_masks( - segmask_pred=y_pred["segmask"], + segmask_pred=segmask_pred, class_pred=y_pred["classes"], decoded_boxes=y_pred["boxes"], - bbox_foramt=self.bounding_box_format, + bbox_format=self.backbone.bounding_box_format, image_shape=image_shape, ) return y_pred @@ -390,34 +391,46 @@ def _resize_and_pad_mask( num_rois = ops.shape(segmask_pred)[0] image_height, image_width = image_shape[:2] - # Reshape segmask_pred to (num_rois, mask_height, mask_width, 1) to - # use with image resizing functions - segmask_pred = ops.expand_dims(segmask_pred, 1) - # Initialize a list to store the padded masks padded_masks_list = [] # Iterate over the batch and place the resized masks into the correct # position for i in range(num_rois): - if class_pred[i] == -1: - continue - y1, x1, y2, x2 = ops.maximum(ops.cast(decoded_boxes[i], "int32"), 0) - y1, y2 = ops.minimum([y1, y2], image_height) - x1, x2 = ops.minimum([x1, x2], image_width) + bounding_box = ops.maximum(ops.cast(decoded_boxes[i], "int32"), 0) + bounding_box = ops.minimum( + bounding_box, [image_height, image_width] * 2 + ) + y1, x1, y2, x2 = ops.unstack(bounding_box) box_height = y2 - y1 box_width = x2 - x1 - # Resize the mask to the size of the bounding box - resized_mask = tf.image.resize( - segmask_pred[i], size=(box_height, box_width) - ) - - # Place the resized mask into the correct position in the final mask - padded_mask = tf.image.pad_to_bounding_box( - resized_mask, y1, x1, image_height, image_width + def do_resize(): + # Resize the mask to the size of the bounding box + resized_mask = ops.image.resize( + 
segmask_pred[i], size=(box_height, box_width) + ) + resized_mask = ops.squeeze(resized_mask, axis=-1) + + # Place the resized mask into the correct position + # in the final mask + padded_mask = ops.pad( + resized_mask, + ( + (y1, image_height - y1 - box_height), + (x1, image_width - x1 - box_width), + ), + ) + return padded_mask + + # Only consider bounding boxes for valid predictions + padded_mask = ops.cond( + ops.all([class_pred[i] != -1, box_height > 0, box_width > 0]), + do_resize, + lambda: ops.zeros( + (image_height, image_width), dtype=segmask_pred.dtype + ), ) - padded_mask = ops.squeeze(padded_mask, axis=-1) # Append the padded mask to the list padded_masks_list.append(padded_mask) @@ -438,7 +451,7 @@ def decode_segmentation_masks( ) # pick the mask prediction for the predicted class segmask_pred = ops.take_along_axis( - segmask_pred, class_pred[:, :, None, None, None] + 1, axis=-1 + segmask_pred, class_pred[:, :, None, None, None], axis=-1 ) final_masks = [] diff --git a/keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression.py b/keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression.py new file mode 100644 index 0000000000..77f130fcb3 --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression.py @@ -0,0 +1,566 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import tensorflow as tf + +from keras_cv.src import bounding_box +from keras_cv.src.api_export import keras_cv_export +from keras_cv.src.backend import keras +from keras_cv.src.backend import ops +from keras_cv.src.backend.config import keras_3 + +EPSILON = 1e-8 + + +@keras_cv_export("keras_cv.layers.NonMaxSuppression") +class NonMaxSuppression(keras.layers.Layer): + """A Keras layer that decodes predictions of an object detection model. + + Args: + bounding_box_format: The format of bounding boxes of input dataset. Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box + formats. + from_logits: boolean, True means input score is logits, False means + confidence. + iou_threshold: a float value in the range [0, 1] representing the minimum + IoU threshold for two boxes to be considered same for suppression. + Defaults to 0.5. + confidence_threshold: a float value in the range [0, 1]. All boxes with + confidence below this value will be discarded, defaults to 0.5. + max_detections: the maximum detections to consider after nms is applied. A + large number may trigger significant memory overhead, defaults to 100. 
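+
+    A minimal usage sketch (illustrative only; the boxes and scores below are
+    made up for demonstration):
+
+    ```python
+    decoder = NonMaxSuppression(
+        bounding_box_format="xyxy",
+        from_logits=False,
+        iou_threshold=0.5,
+        confidence_threshold=0.5,
+        max_detections=5,
+    )
+    boxes = np.array(
+        [[[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]]], "float32"
+    )
+    scores = np.array([[[0.9, 0.1], [0.6, 0.2], [0.1, 0.8]]], "float32")
+    # The second box overlaps the first with IoU > 0.5 at a lower
+    # confidence, so it is suppressed; two detections remain.
+    detections = decoder(boxes, scores, image_shape=(128, 128, 3))
+    # detections["boxes"], detections["classes"], detections["confidence"],
+    # detections["num_detections"]
+    ```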
+ """ # noqa: E501 + + def __init__( + self, + bounding_box_format, + from_logits, + iou_threshold=0.5, + confidence_threshold=0.5, + max_detections=100, + **kwargs, + ): + super().__init__(**kwargs) + self.bounding_box_format = bounding_box_format + self.from_logits = from_logits + self.iou_threshold = iou_threshold + self.confidence_threshold = confidence_threshold + self.max_detections = max_detections + self.built = True + + def call( + self, + box_prediction, + class_prediction, + mask_prediction=None, + images=None, + image_shape=None, + ): + """Accepts images and raw predictions, and returns bounding box + predictions. + + Args: + box_prediction: Dense Tensor of shape [batch, boxes, 4] in the + `bounding_box_format` specified in the constructor. + class_prediction: Dense Tensor of shape [batch, boxes, num_classes]. + mask_prediction: Dense Tensor of shape [batch, boxes, mask_height, + mask_width]. + """ + target_format = "yxyx" + if bounding_box.is_relative(self.bounding_box_format): + target_format = bounding_box.as_relative(target_format) + + box_prediction = bounding_box.convert_format( + box_prediction, + source=self.bounding_box_format, + target=target_format, + images=images, + image_shape=image_shape, + ) + if self.from_logits: + class_prediction = ops.sigmoid(class_prediction) + + confidence_prediction = ops.max(class_prediction, axis=-1) + + if not keras_3() or keras.backend.backend() == "tensorflow": + idx, valid_det = tf.image.non_max_suppression_padded( + box_prediction, + confidence_prediction, + max_output_size=self.max_detections, + iou_threshold=self.iou_threshold, + score_threshold=self.confidence_threshold, + pad_to_max_output_size=True, + sorted_input=False, + ) + elif keras.backend.backend() == "torch": + # Since TorchVision has a nice efficient NMS op, we might as well + # use it! 
+ import torchvision + + batch_size = box_prediction.shape[0] + idx = ops.zeros((batch_size, self.max_detections)) + valid_det = ops.zeros((batch_size), "int32") + + for batch_idx in range(batch_size): + conf_mask = ( + confidence_prediction[batch_idx] > self.confidence_threshold + ) + conf_mask_idx = ops.squeeze(ops.nonzero(conf_mask), axis=0) + conf_i = confidence_prediction[batch_idx][conf_mask] + box_i = box_prediction[batch_idx][conf_mask] + + idx_i = torchvision.ops.nms( + box_i, conf_i, iou_threshold=self.iou_threshold + ) + + idx_i = conf_mask_idx[idx_i] + + num_boxes = idx_i.shape[0] + if num_boxes >= self.max_detections: + idx_i = idx_i[: self.max_detections] + num_boxes = self.max_detections + + valid_det[batch_idx] = ops.cast(ops.size(idx_i), "int32") + idx[batch_idx, :num_boxes] = idx_i + else: + idx, valid_det = non_max_suppression( + box_prediction, + confidence_prediction, + max_output_size=self.max_detections, + iou_threshold=self.iou_threshold, + score_threshold=self.confidence_threshold, + ) + + box_prediction = ops.take_along_axis( + box_prediction, ops.expand_dims(idx, axis=-1), axis=1 + ) + box_prediction = ops.reshape( + box_prediction, (-1, self.max_detections, 4) + ) + confidence_prediction = ops.take_along_axis( + confidence_prediction, idx, axis=1 + ) + class_prediction = ops.take_along_axis( + class_prediction, ops.expand_dims(idx, axis=-1), axis=1 + ) + + if mask_prediction is not None: + mask_prediction = ops.take_along_axis( + mask_prediction, idx[..., None, None, None], axis=1 + ) + + box_prediction = bounding_box.convert_format( + box_prediction, + source=target_format, + target=self.bounding_box_format, + images=images, + image_shape=image_shape, + ) + bounding_boxes = { + "boxes": box_prediction, + "confidence": confidence_prediction, + "classes": ops.argmax(class_prediction, axis=-1), + "num_detections": valid_det, + } + if mask_prediction is not None: + bounding_boxes["segmask"] = mask_prediction + + # this is required to comply with KerasCV bounding box format. + return bounding_box.mask_invalid_detections( + bounding_boxes, output_ragged=False + ) + + def get_config(self): + config = { + "bounding_box_format": self.bounding_box_format, + "from_logits": self.from_logits, + "iou_threshold": self.iou_threshold, + "confidence_threshold": self.confidence_threshold, + "max_detections": self.max_detections, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +def non_max_suppression( + boxes, + scores, + max_output_size, + iou_threshold=0.5, + score_threshold=0.0, + tile_size=512, +): + # Box format must be yxyx + """Non-maximum suppression. + Ported from https://github.com/tensorflow/tensorflow/blob/v2.12.0/tensorflow/python/ops/image_ops_impl.py#L5368-L5458 + + Args: + boxes: a tensor of rank 2 or higher with a shape of [..., num_boxes, 4]. + Dimensions except the last two are batch dimensions. The last dimension + represents box coordinates in yxyx format. + scores: a tensor of rank 1 or higher with a shape of [..., num_boxes]. + max_output_size: a scalar integer tensor representing the maximum number + of boxes to be selected by non max suppression. + iou_threshold: a float representing the threshold for deciding whether boxes + overlap too much with respect to IoU (intersection over union). + score_threshold: a float representing the threshold for box scores. Boxes + with a score that is not larger than this threshold will be suppressed. 
+ tile_size: an integer representing the number of boxes in a tile, i.e., + the maximum number of boxes per image that can be used to suppress other + boxes in parallel; larger tile_size means larger parallelism and + potentially more redundant work. + + Returns: + idx: a tensor with a shape of [..., num_boxes] representing the + indices selected by non-max suppression. The leading dimensions + are the batch dimensions of the input boxes. All numbers are within + [0, num_boxes). For each image (i.e., idx[i]), only the first num_valid[i] + indices (i.e., idx[i][:num_valid[i]]) are valid. + num_valid: a tensor of rank 0 or higher with a shape of [...] + representing the number of valid indices in idx. Its dimensions are the + batch dimensions of the input boxes. + """ # noqa: E501 + + def _sort_scores_and_boxes(scores, boxes): + """Sort boxes based their score from highest to lowest. + + Args: + scores: a tensor with a shape of [batch_size, num_boxes] representing + the scores of boxes. + boxes: a tensor with a shape of [batch_size, num_boxes, 4] representing + the boxes. + + Returns: + sorted_scores: a tensor with a shape of [batch_size, num_boxes] + representing the sorted scores. + sorted_boxes: a tensor representing the sorted boxes. + sorted_scores_indices: a tensor with a shape of [batch_size, num_boxes] + representing the index of the scores in a sorted descending order. + """ # noqa: E501 + with ops.name_scope("sort_scores_and_boxes"): + sorted_scores_indices = ops.flip( + ops.cast(ops.argsort(scores, axis=1), "int32"), axis=1 + ) + sorted_scores = ops.take_along_axis( + scores, + sorted_scores_indices, + axis=1, + ) + sorted_boxes = ops.take_along_axis( + boxes, + ops.expand_dims(sorted_scores_indices, axis=-1), + axis=1, + ) + return sorted_scores, sorted_boxes, sorted_scores_indices + + batch_dims = ops.shape(boxes)[:-2] + num_boxes = boxes.shape[-2] + boxes = ops.reshape(boxes, [-1, num_boxes, 4]) + scores = ops.reshape(scores, [-1, num_boxes]) + batch_size = boxes.shape[0] + if score_threshold != float("-inf"): + with ops.name_scope("filter_by_score"): + score_mask = ops.cast(scores > score_threshold, scores.dtype) + scores *= score_mask + box_mask = ops.expand_dims(ops.cast(score_mask, boxes.dtype), 2) + boxes *= box_mask + + scores, boxes, sorted_indices = _sort_scores_and_boxes(scores, boxes) + + pad = ( + math.ceil(max(num_boxes, max_output_size) / tile_size) * tile_size + - num_boxes + ) + boxes = ops.pad(ops.cast(boxes, "float32"), [[0, 0], [0, pad], [0, 0]]) + scores = ops.pad(ops.cast(scores, "float32"), [[0, 0], [0, pad]]) + num_boxes_after_padding = num_boxes + pad + num_iterations = num_boxes_after_padding // tile_size + + def _loop_cond(unused_boxes, unused_threshold, output_size, idx): + return ops.logical_and( + ops.min(output_size) < ops.cast(max_output_size, "int32"), + ops.cast(idx, "int32") < num_iterations, + ) + + def suppression_loop_body(boxes, iou_threshold, output_size, idx): + return _suppression_loop_body( + boxes, iou_threshold, output_size, idx, tile_size + ) + + selected_boxes, _, output_size, _ = ops.while_loop( + _loop_cond, + suppression_loop_body, + [ + boxes, + iou_threshold, + ops.zeros([batch_size], "int32"), + ops.array(0), + ], + ) + num_valid = ops.minimum(output_size, max_output_size) + idx = num_boxes_after_padding - ops.cast( + ops.top_k( + ops.cast(ops.any(selected_boxes > 0, [2]), "int32") + * ops.cast( + ops.expand_dims(ops.arange(num_boxes_after_padding, 0, -1), 0), + "int32", + ), + max_output_size, + )[0], + "int32", + ) + idx = 
ops.minimum(idx, num_boxes - 1) + + index_offsets = ops.cast(ops.arange(batch_size) * num_boxes, "int32") + take_along_axis_idx = ops.reshape( + idx + ops.expand_dims(index_offsets, 1), [-1] + ) + + # TODO(ianstenbit): Fix bug in tfnp.take_along_axis that causes this hack. + # (This will be removed anyway when we use built-in NMS for TF.) + if keras_3() and keras.backend.backend() != "tensorflow": + idx = ops.take_along_axis( + ops.reshape(sorted_indices, [-1]), take_along_axis_idx + ) + else: + import tensorflow as tf + + idx = tf.gather(ops.reshape(sorted_indices, [-1]), take_along_axis_idx) + idx = ops.reshape(idx, [batch_size, -1]) + + invalid_index = ops.zeros([batch_size, max_output_size], dtype="int32") + idx_index = ops.cast( + ops.expand_dims(ops.arange(max_output_size), 0), "int32" + ) + num_valid_expanded = ops.expand_dims(num_valid, 1) + idx = ops.where(idx_index < num_valid_expanded, idx, invalid_index) + + num_valid = ops.reshape(num_valid, batch_dims) + return idx, num_valid + + +def _bbox_overlap(boxes_a, boxes_b): + """Calculates the overlap (iou - intersection over union) between boxes_a and boxes_b. + + Args: + boxes_a: a tensor with a shape of [batch_size, N, 4]. N is the number of + boxes per image. The last dimension is the pixel coordinates in + [ymin, xmin, ymax, xmax] form. + boxes_b: a tensor with a shape of [batch_size, M, 4]. M is the number of + boxes. The last dimension is the pixel coordinates in + [ymin, xmin, ymax, xmax] form. + + Returns: + intersection_over_union: a tensor with as a shape of [batch_size, N, M], + representing the ratio of intersection area over union area (IoU) between + two boxes + """ # noqa: E501 + with ops.name_scope("bbox_overlap"): + if len(boxes_a.shape) == 4: + boxes_a = ops.squeeze(boxes_a, axis=0) + a_y_min, a_x_min, a_y_max, a_x_max = ops.split(boxes_a, 4, axis=2) + b_y_min, b_x_min, b_y_max, b_x_max = ops.split(boxes_b, 4, axis=2) + + # Calculates the intersection area. + i_xmin = ops.maximum(a_x_min, ops.transpose(b_x_min, [0, 2, 1])) + i_xmax = ops.minimum(a_x_max, ops.transpose(b_x_max, [0, 2, 1])) + i_ymin = ops.maximum(a_y_min, ops.transpose(b_y_min, [0, 2, 1])) + i_ymax = ops.minimum(a_y_max, ops.transpose(b_y_max, [0, 2, 1])) + i_area = ops.maximum((i_xmax - i_xmin), 0) * ops.maximum( + (i_ymax - i_ymin), 0 + ) + + # Calculates the union area. + a_area = (a_y_max - a_y_min) * (a_x_max - a_x_min) + b_area = (b_y_max - b_y_min) * (b_x_max - b_x_min) + + # Adds a small epsilon to avoid divide-by-zero. + u_area = a_area + ops.transpose(b_area, [0, 2, 1]) - i_area + EPSILON + + intersection_over_union = i_area / u_area + + return intersection_over_union + + +def _self_suppression(iou, _, iou_sum, iou_threshold): + """Suppress boxes in the same tile. + + Compute boxes that cannot be suppressed by others (i.e., + can_suppress_others), and then use them to suppress boxes in the same tile. + + Args: + iou: a tensor of shape [batch_size, num_boxes_with_padding] representing + intersection over union. + iou_sum: a scalar tensor. + iou_threshold: a scalar tensor. + + Returns: + iou_suppressed: a tensor of shape [batch_size, num_boxes_with_padding]. + iou_diff: a scalar tensor representing whether any box is supressed in + this step. + iou_sum_new: a scalar tensor of shape [batch_size] that represents + the iou sum after suppression. + iou_threshold: a scalar tensor. 
+ """ # noqa: E501 + batch_size = ops.shape(iou)[0] + can_suppress_others = ops.cast( + ops.reshape(ops.max(iou, 1) < iou_threshold, [batch_size, -1, 1]), + iou.dtype, + ) + iou_after_suppression = ( + ops.reshape( + ops.cast( + ops.max(can_suppress_others * iou, 1) < iou_threshold, iou.dtype + ), + [batch_size, -1, 1], + ) + * iou + ) + iou_sum_new = ops.sum(iou_after_suppression, [1, 2]) + return [ + iou_after_suppression, + ops.any(iou_sum - iou_sum_new > iou_threshold), + iou_sum_new, + iou_threshold, + ] + + +def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx, tile_size): + """Suppress boxes between different tiles. + + Args: + boxes: a tensor of shape [batch_size, num_boxes_with_padding, 4] + box_slice: a tensor of shape [batch_size, tile_size, 4] + iou_threshold: a scalar tensor + inner_idx: a scalar tensor representing the tile index of the tile + that is used to supress box_slice + tile_size: an integer representing the number of boxes in a tile + + Returns: + boxes: unchanged boxes as input + box_slice_after_suppression: box_slice after suppression + iou_threshold: unchanged + """ + slice_index = ops.expand_dims( + ops.expand_dims( + ops.cast( + ops.linspace( + inner_idx * tile_size, + (inner_idx + 1) * tile_size - 1, + tile_size, + ), + "int32", + ), + axis=0, + ), + axis=-1, + ) + new_slice = ops.expand_dims( + ops.take_along_axis(boxes, slice_index, axis=1), 0 + ) + iou = _bbox_overlap(new_slice, box_slice) + box_slice_after_suppression = ( + ops.expand_dims( + ops.cast(ops.all(iou < iou_threshold, [1]), box_slice.dtype), 2 + ) + * box_slice + ) + return boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1 + + +def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size): + """Process boxes in the range [idx*tile_size, (idx+1)*tile_size). + + Args: + boxes: a tensor with a shape of [batch_size, anchors, 4]. + iou_threshold: a float representing the threshold for deciding whether boxes + overlap too much with respect to IOU. + output_size: an int32 tensor of size [batch_size]. Representing the number + of selected boxes for each batch. + idx: an integer scalar representing induction variable. + tile_size: an integer representing the number of boxes in a tile + + Returns: + boxes: updated boxes. + iou_threshold: pass down iou_threshold to the next iteration. + output_size: the updated output_size. + idx: the updated induction variable. + """ # noqa: E501 + with ops.name_scope("suppression_loop_body"): + num_tiles = boxes.shape[1] // tile_size + batch_size = boxes.shape[0] + + def cross_suppression_func(boxes, box_slice, iou_threshold, inner_idx): + return _cross_suppression( + boxes, box_slice, iou_threshold, inner_idx, tile_size + ) + + # Iterates over tiles that can possibly suppress the current tile. + slice_index = ops.expand_dims( + ops.expand_dims( + ops.cast( + ops.linspace( + idx * tile_size, (idx + 1) * tile_size - 1, tile_size + ), + "int32", + ), + axis=0, + ), + axis=-1, + ) + box_slice = ops.take_along_axis(boxes, slice_index, axis=1) + _, box_slice, _, _ = ops.while_loop( + lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx, + cross_suppression_func, + [boxes, box_slice, iou_threshold, ops.array(0)], + ) + + # Iterates over the current tile to compute self-suppression. 
+ iou = _bbox_overlap(box_slice, box_slice) + mask = ops.expand_dims( + ops.reshape(ops.arange(tile_size), [1, -1]) + > ops.reshape(ops.arange(tile_size), [-1, 1]), + 0, + ) + iou *= ops.cast(ops.logical_and(mask, iou >= iou_threshold), iou.dtype) + suppressed_iou, _, _, _ = ops.while_loop( + lambda _iou, loop_condition, _iou_sum, _: loop_condition, + _self_suppression, + [iou, ops.array(True), ops.sum(iou, [1, 2]), iou_threshold], + ) + suppressed_box = ops.sum(suppressed_iou, 1) > 0 + box_slice *= ops.expand_dims( + 1.0 - ops.cast(suppressed_box, box_slice.dtype), 2 + ) + + # Uses box_slice to update the input boxes. + mask = ops.reshape( + ops.cast(ops.equal(ops.arange(num_tiles), idx), boxes.dtype), + [1, -1, 1, 1], + ) + boxes = ops.tile( + ops.expand_dims(box_slice, 1), [1, num_tiles, 1, 1] + ) * mask + ops.reshape(boxes, [batch_size, num_tiles, tile_size, 4]) * ( + 1 - mask + ) + boxes = ops.reshape(boxes, [batch_size, -1, 4]) + + # Updates output_size. + output_size += ops.cast( + ops.sum(ops.any(box_slice > 0, [2]), [1]), "int32" + ) + return boxes, iou_threshold, output_size, idx + 1 \ No newline at end of file diff --git a/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler.py b/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler.py new file mode 100644 index 0000000000..7a3707748b --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler.py @@ -0,0 +1,318 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.src import bounding_box +from keras_cv.src.backend import keras +from keras_cv.src.backend import ops +from keras_cv.src.bounding_box import iou +from keras_cv.src.layers.object_detection import box_matcher +from keras_cv.src.layers.object_detection import sampling +from keras_cv.src.utils import target_gather + + +@keras.utils.register_keras_serializable(package="keras_cv") +class ROISampler(keras.layers.Layer): + """ + Sample ROIs for loss related calculation. + + With proposals (ROIs) and ground truth, it performs the following: + 1) compute IOU similarity matrix + 2) match each proposal to ground truth box based on IOU + 3) samples positive matches and negative matches and return + + `append_gt_boxes` augments proposals with ground truth boxes. This is + useful in 2 stage detection networks during initialization where the + 1st stage often cannot produce good proposals for 2nd stage. Setting it to + True will allow it to generate more reasonable proposals at the beginning. + + `background_class` allow users to set the labels for background proposals. + Default is 0, where users need to manually shift the incoming `gt_classes` + if its range is [0, num_classes). + + Args: + roi_bounding_box_format: The format of roi bounding boxes. Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box formats. + gt_bounding_box_format: The format of ground truth bounding boxes. 
+ roi_matcher: a `BoxMatcher` object that matches proposals with ground + truth boxes. The positive match must be 1 and negative match must be -1. + Such assumption is not being validated here. + positive_fraction: the positive ratio w.r.t `num_sampled_rois`, defaults + to 0.25. + background_class: the background class which is used to map returned the + sampled ground truth which is classified as background. + num_sampled_rois: the number of sampled proposals per image for + further (loss) calculation, defaults to 256. + append_gt_boxes: boolean, whether gt_boxes will be appended to rois + before sample the rois, defaults to True. + mask_shape: The shape of segmentation masks used for training, + defaults to (14,14). + """ # noqa: E501 + + def __init__( + self, + roi_bounding_box_format: str, + gt_bounding_box_format: str, + roi_matcher: box_matcher.BoxMatcher, + positive_fraction: float = 0.25, + background_class: int = 0, + num_sampled_rois: int = 256, + append_gt_boxes: bool = True, + mask_shape=(14, 14), + **kwargs, + ): + super().__init__(**kwargs) + self.roi_bounding_box_format = roi_bounding_box_format + self.gt_bounding_box_format = gt_bounding_box_format + self.roi_matcher = roi_matcher + self.positive_fraction = positive_fraction + self.background_class = background_class + self.num_sampled_rois = num_sampled_rois + self.append_gt_boxes = append_gt_boxes + self.mask_shape = mask_shape + self.seed_generator = keras.random.SeedGenerator() + self.built = True + # for debugging. + self._positives = keras.metrics.Mean() + self._negatives = keras.metrics.Mean() + + def call(self, rois, gt_boxes, gt_classes, gt_masks=None): + """ + Args: + rois: [batch_size, num_rois, 4] + gt_boxes: [batch_size, num_gt, 4] + gt_classes: [batch_size, num_gt, 1] + gt_masks: [batch_size, num_gt, height, width] + Returns: + sampled_rois: [batch_size, num_sampled_rois, 4] + sampled_gt_boxes: [batch_size, num_sampled_rois, 4] + sampled_box_weights: [batch_size, num_sampled_rois, 1] + sampled_gt_classes: [batch_size, num_sampled_rois, 1] + sampled_class_weights: [batch_size, num_sampled_rois, 1] + sampled_gt_masks: + [batch_size, num_sampled_rois, mask_height, mask_width] + sampled_mask_weights: [batch_size, num_sampled_rois, 1] + """ + rois = bounding_box.convert_format( + rois, source=self.roi_bounding_box_format, target="yxyx" + ) + gt_boxes = bounding_box.convert_format( + gt_boxes, source=self.gt_bounding_box_format, target="yxyx" + ) + if self.append_gt_boxes: + # num_rois += num_gt + rois = ops.concatenate([rois, gt_boxes], axis=1) + num_rois = ops.shape(rois)[1] + if num_rois is None: + raise ValueError( + f"`rois` must have static shape, got {ops.shape(rois)}" + ) + if num_rois < self.num_sampled_rois: + raise ValueError( + "num_rois must be less than `num_sampled_rois` " + f"({self.num_sampled_rois}), got {num_rois}" + ) + # [batch_size, num_rois, num_gt] + similarity_mat = iou.compute_iou( + rois, gt_boxes, bounding_box_format="yxyx", use_masking=True + ) + # [batch_size, num_rois] | [batch_size, num_rois] + matched_gt_cols, matched_vals = self.roi_matcher(similarity_mat) + # [batch_size, num_rois] + positive_matches = ops.equal(matched_vals, 1) + negative_matches = ops.equal(matched_vals, -1) + self._positives.update_state( + ops.sum(ops.cast(positive_matches, "float32"), axis=-1) + ) + self._negatives.update_state( + ops.sum(ops.cast(negative_matches, "float32"), axis=-1) + ) + # [batch_size, num_rois, 1] + background_mask = ops.expand_dims( + ops.logical_not(positive_matches), axis=-1 + ) + # 
[batch_size, num_rois, 1] + matched_gt_classes = target_gather._target_gather( + gt_classes, matched_gt_cols + ) + # also set all background matches to `background_class` + matched_gt_classes = ops.where( + background_mask, + ops.cast( + self.background_class * ops.ones_like(matched_gt_classes), + gt_classes.dtype, + ), + matched_gt_classes, + ) + # [batch_size, num_rois, 4] + matched_gt_boxes = target_gather._target_gather( + gt_boxes, matched_gt_cols + ) + encoded_matched_gt_boxes = bounding_box._encode_box_to_deltas( + anchors=rois, + boxes=matched_gt_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=[0.1, 0.1, 0.2, 0.2], + ) + # also set all background matches to 0 coordinates + encoded_matched_gt_boxes = ops.where( + background_mask, + ops.zeros_like(matched_gt_boxes), + encoded_matched_gt_boxes, + ) + # [batch_size, num_rois] + sampled_indicators = sampling.balanced_sample( + positive_matches, + negative_matches, + self.num_sampled_rois, + self.positive_fraction, + seed=self.seed_generator, + ) + # [batch_size, num_sampled_rois] in the range of [0, num_rois) + sampled_indicators, sampled_indices = ops.top_k( + sampled_indicators, k=self.num_sampled_rois, sorted=True + ) + # [batch_size, num_sampled_rois, 4] + sampled_rois = target_gather._target_gather(rois, sampled_indices) + # [batch_size, num_sampled_rois, 4] + sampled_gt_boxes = target_gather._target_gather( + encoded_matched_gt_boxes, sampled_indices + ) + # [batch_size, num_sampled_rois, 1] + sampled_gt_classes = target_gather._target_gather( + matched_gt_classes, sampled_indices + ) + # [batch_size, num_sampled_rois, 1] + # all negative samples will be ignored in regression + sampled_box_weights = target_gather._target_gather( + ops.cast(positive_matches[..., None], gt_boxes.dtype), + sampled_indices, + ) + # [batch_size, num_sampled_rois, 1] + sampled_indicators = sampled_indicators[..., None] + sampled_class_weights = ops.cast(sampled_indicators, gt_classes.dtype) + + if gt_masks is not None: + sampled_gt_cols = target_gather._target_gather( + matched_gt_cols[:, :, None], sampled_indices + ) + + # [batch_size, num_sampled_rois, height, width] + cropped_and_resized_masks = crop_and_resize( + ops.expand_dims(gt_masks, axis=-1), + bounding_boxes=sampled_rois, + target_size=self.mask_shape, + ) + cropped_and_resized_masks = ops.squeeze( + cropped_and_resized_masks, axis=-1 + ) + + sampled_gt_masks = ops.equal( + cropped_and_resized_masks, sampled_gt_cols[..., None] + 1 + ) + sampled_gt_masks = ops.cast(sampled_gt_masks, "float32") + + # Mask weights: 1 for positive samples, 0 for background + sampled_mask_weights = sampled_box_weights + + sampled_data = ( + sampled_rois, + sampled_gt_boxes, + sampled_box_weights, + sampled_gt_classes, + sampled_class_weights, + ) + if gt_masks is not None: + sampled_data = sampled_data + ( + sampled_gt_masks, + sampled_mask_weights, + ) + return sampled_data + + def get_config(self): + config = super().get_config() + config["roi_bounding_box_format"] = self.roi_bounding_box_format + config["gt_bounding_box_format"] = self.gt_bounding_box_format + config["positive_fraction"] = self.positive_fraction + config["background_class"] = self.background_class + config["num_sampled_rois"] = self.num_sampled_rois + config["append_gt_boxes"] = self.append_gt_boxes + config["mask_shape"] = self.mask_shape + config["roi_matcher"] = self.roi_matcher.get_config() + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + roi_matcher_config = config.pop("roi_matcher") + 
roi_matcher = box_matcher.BoxMatcher(**roi_matcher_config)
+ return cls(roi_matcher=roi_matcher, **config)
+
+
+def crop_and_resize(images, bounding_boxes, target_size):
+ """
+ A utility function that crops the image regions selected by
+ `bounding_boxes` and resizes each crop to `target_size`.
+
+ `bounding_boxes` is expected to be in yxyx format.
+ """
+
+ num_images, num_boxes = ops.shape(bounding_boxes)[:2]
+ bounding_boxes = ops.cast(bounding_boxes, "int32")
+ channels = ops.shape(images)[3]
+
+ cropped_and_resized_images = []
+ for image_idx in range(num_images):
+ for box_idx in range(num_boxes):
+ y1, x1, y2, x2 = ops.unstack(bounding_boxes[image_idx, box_idx])
+ # crop to the bounding box
+ slice_y = ops.maximum(y1, 0)
+ slice_x = ops.maximum(x1, 0)
+ cropped_image = ops.slice(
+ images[image_idx],
+ (slice_y, slice_x, 0),
+ (y2 - slice_y, x2 - slice_x, channels),
+ )
+ # pad if the bounding box goes beyond the image
+ pad_y = -ops.minimum(y1, 0)
+ pad_x = -ops.minimum(x1, 0)
+ cropped_image = ops.pad(
+ cropped_image,
+ (
+ (
+ pad_y,
+ ops.maximum(y2 - y1, 1)
+ - ops.shape(cropped_image)[0]
+ - pad_y,
+ ),
+ (
+ pad_x,
+ ops.maximum(x2 - x1, 1)
+ - ops.shape(cropped_image)[1]
+ - pad_x,
+ ),
+ (0, 0),
+ ),
+ )
+ # resize to the target size
+ resized_image = ops.image.resize(cropped_image, target_size)
+ cropped_and_resized_images.append(resized_image)
+
+ cropped_and_resized_images = ops.stack(cropped_and_resized_images, axis=0)
+
+ target_shape = (num_images, num_boxes, *target_size, channels)
+ cropped_and_resized_images = ops.reshape(
+ cropped_and_resized_images, target_shape
+ )
+ return cropped_and_resized_images
\ No newline at end of file
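Below is a minimal, illustrative sketch (not part of the patch) of how the new `ROISampler` layer can be exercised on toy data. The `BoxMatcher` thresholds, the toy box values, and the output unpacking are assumptions chosen so that positive matches map to 1 and negatives to -1, as the sampler requires; `gt_masks` is omitted, so the call returns only the five box/class tensors.

```python
import numpy as np

from keras_cv.src.layers.object_detection.box_matcher import BoxMatcher
from keras_cv.src.models.object_detection.mask_rcnn.roi_sampler import (
    ROISampler,
)

# Toy inputs: 1 image, 64 proposals, 2 ground truth boxes (xyxy, pixels).
# Proposal corners are drawn so that x1 <= x2 and y1 <= y2 always hold.
rois = np.concatenate(
    [
        np.random.uniform(0.0, 256.0, size=(1, 64, 2)),
        np.random.uniform(256.0, 512.0, size=(1, 64, 2)),
    ],
    axis=-1,
).astype("float32")
gt_boxes = np.array(
    [[[10, 10, 100, 100], [200, 200, 300, 300]]], dtype="float32"
)
gt_classes = np.array([[[1], [2]]], dtype="float32")

# Assumed matcher configuration: IoU < 0.3 -> -1 (negative),
# 0.3 <= IoU < 0.7 -> -2 (ignored), IoU >= 0.7 -> 1 (positive),
# matching the 1 / -1 convention the sampler relies on.
matcher = BoxMatcher(thresholds=[0.3, 0.7], match_values=[-1, -2, 1])

sampler = ROISampler(
    roi_bounding_box_format="xyxy",
    gt_bounding_box_format="xyxy",
    roi_matcher=matcher,
    num_sampled_rois=32,
)

# Without `gt_masks`, the layer returns the five box/class tensors.
(
    sampled_rois,
    sampled_gt_boxes,
    box_weights,
    sampled_gt_classes,
    class_weights,
) = sampler(rois, gt_boxes, gt_classes)
print(sampled_rois.shape)  # (1, 32, 4)
```

In actual training the proposals would come from the ROI generator of the two-stage detector rather than synthetic boxes; the sketch only shows the layer's input/output contract.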