Skip to content

Commit

Permalink
Port Mask R-CNN to Keras3 (#2483)
Browse files Browse the repository at this point in the history
* Port Faster R-CNN to Keras3

* Port Mask R-CNN to Keras3

* added the processing of mask predictions

* Revert changes

* added faster_rcnn as backbone for mask_rcnn

* add mask predictions in faster_rcnn

* remove multiple variable declarations in roi_sampler

* removing changes to nms.py and roi_sampler.py

* add newline at end to revert

* removed extraneous whitespace

* removing changes to fasterrcnn.py

* Move files to maskrcnn folder and addressed all the required changes

* Restructuring and Backbone implementation changes

* address format issues

* adding test cases

* adding maskrcnn into workflow

* Fix order of decorators and JAX integer dtype error

* Fix format

* Fix tests for GPU runs

* Revert keras version to 3.3.3 in build system

* Avoid TimeDistributed layers for Keras 3.3.3 compatibility and acknowledge randomness in tests
  • Loading branch information
laxmareddyp authored Sep 26, 2024
1 parent 4e0855f commit 4dfa66f
Show file tree
Hide file tree
Showing 14 changed files with 2,874 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ jobs:
keras_cv/src/models/object_detection/retinanet \
keras_cv/src/models/object_detection/yolo_v8 \
keras_cv/src/models/object_detection/faster_rcnn \
keras_cv/src/models/object_detection/mask_rcnn \
keras_cv/src/models/object_detection_3d \
keras_cv/src/models/segmentation \
--durations 0
Expand Down
2 changes: 2 additions & 0 deletions .kokoro/github/ubuntu/gpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ then
keras_cv/src/models/object_detection/retinanet \
keras_cv/src/models/object_detection/yolo_v8 \
keras_cv/src/models/object_detection/faster_rcnn \
keras_cv/src/models/object_detection/mask_rcnn \
keras_cv/src/models/object_detection_3d \
keras_cv/src/models/segmentation \
keras_cv/src/models/feature_extractor/clip \
Expand All @@ -88,6 +89,7 @@ else
keras_cv/src/models/object_detection/retinanet \
keras_cv/src/models/object_detection/yolo_v8 \
keras_cv/src/models/object_detection/faster_rcnn \
keras_cv/src/models/object_detection/mask_rcnn \
keras_cv/src/models/object_detection_3d \
keras_cv/src/models/segmentation \
keras_cv/src/models/feature_extractor/clip \
Expand Down
2 changes: 2 additions & 0 deletions keras_cv/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from keras_cv.api.models import classification
from keras_cv.api.models import faster_rcnn
from keras_cv.api.models import feature_extractor
from keras_cv.api.models import mask_rcnn
from keras_cv.api.models import object_detection
from keras_cv.api.models import retinanet
from keras_cv.api.models import segmentation
Expand Down Expand Up @@ -209,6 +210,7 @@
from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import (
FasterRCNN,
)
from keras_cv.src.models.object_detection.mask_rcnn.mask_rcnn import MaskRCNN
from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet
from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_backbone import (
YOLOV8Backbone,
Expand Down
7 changes: 7 additions & 0 deletions keras_cv/api/models/mask_rcnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""DO NOT EDIT.
This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras_cv.src.models.object_detection.mask_rcnn.mask_head import MaskHead
1 change: 1 addition & 0 deletions keras_cv/api/models/object_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import (
FasterRCNN,
)
from keras_cv.src.models.object_detection.mask_rcnn.mask_rcnn import MaskRCNN
from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet
from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_detector import (
YOLOV8Detector,
Expand Down
17 changes: 17 additions & 0 deletions keras_cv/src/models/object_detection/mask_rcnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from keras_cv.src.models.object_detection.mask_rcnn.mask_head import MaskHead
from keras_cv.src.models.object_detection.mask_rcnn.mask_rcnn import MaskRCNN
117 changes: 117 additions & 0 deletions keras_cv/src/models/object_detection/mask_rcnn/mask_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.src.api_export import keras_cv_export
from keras_cv.src.backend import keras


@keras_cv_export(
    "keras_cv.models.mask_rcnn.MaskHead",
    package="keras_cv.models.mask_rcnn",
)
class MaskHead(keras.layers.Layer):
    """A Keras layer implementing the R-CNN Mask Head.

    The architecture is adopted from Matterport's Mask R-CNN implementation
    https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/model.py.

    The layer takes a 5D input whose axis 1 is the ROI axis, i.e.
    `(batch, num_rois, height, width, channels)`. The batch and ROI axes
    are merged before the convolutions and split again afterwards, so the
    output has shape
    `(batch, num_rois, 2 * height, 2 * width, num_classes + 1)`.
    The output contains logits; no sigmoid is applied.

    Args:
        num_classes: The number of object classes that are being detected,
            excluding the background class.
        stackwise_num_conv_filters: (Optional) a list of integers specifying
            the number of filters for each convolutional layer. Defaults
            to [256, 256].
        num_deconv_filters: (Optional) the number of filters to use in the
            upsampling convolutional layer. Defaults to 256.
    """

    def __init__(
        self,
        num_classes,
        stackwise_num_conv_filters=None,
        num_deconv_filters=256,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # A `None` sentinel replaces the former mutable list default, so
        # layers created with the default no longer share one list object.
        if stackwise_num_conv_filters is None:
            stackwise_num_conv_filters = [256, 256]
        self.num_classes = num_classes
        # Copy so later mutation of the caller's sequence cannot silently
        # change this layer's recorded configuration.
        self.stackwise_num_conv_filters = list(stackwise_num_conv_filters)
        self.num_deconv_filters = num_deconv_filters
        # Flat list of [conv, batchnorm, activation] triples; `build` relies
        # on this layout when indexing with `idx * 3`.
        self.layers = []
        for num_filters in stackwise_num_conv_filters:
            conv = keras.layers.Conv2D(
                filters=num_filters,
                kernel_size=3,
                padding="same",
            )
            batchnorm = keras.layers.BatchNormalization()
            activation = keras.layers.Activation("relu")
            self.layers.extend([conv, batchnorm, activation])

        # Stride-2 transposed convolution doubles the spatial resolution.
        self.deconv = keras.layers.Conv2DTranspose(
            num_deconv_filters,
            kernel_size=2,
            strides=2,
            activation="relu",
            padding="valid",
        )
        # we do not use a final sigmoid activation, since we use
        # from_logits=True during training
        self.segmentation_mask_output = keras.layers.Conv2D(
            num_classes + 1,
            kernel_size=1,
            strides=1,
            activation="linear",
        )

    def call(self, feature_map, training=False):
        """Computes per-ROI mask logits.

        Args:
            feature_map: 5D tensor with the ROI axis at position 1.
            training: Forwarded to every layer of the convolutional stack.

        Returns:
            Tensor of mask logits with the batch and ROI axes restored in
            front of the convolution output axes.
        """
        # reshape batch and ROI axes into one axis to obtain a suitable
        # shape for conv layers
        num_rois = keras.ops.shape(feature_map)[1]
        x = keras.ops.reshape(feature_map, (-1, *feature_map.shape[2:]))
        for layer in self.layers:
            x = layer(x, training=training)
        x = self.deconv(x)
        segmentation_mask = self.segmentation_mask_output(x)
        # Split the merged leading axis back into (batch, num_rois).
        segmentation_mask = keras.ops.reshape(
            segmentation_mask, (-1, num_rois, *segmentation_mask.shape[1:])
        )
        return segmentation_mask

    def build(self, input_shape):
        """Builds the sub-layers for the given 5D `input_shape`.

        The shape is propagated by hand through the conv stack, the
        transposed convolution and the output convolution, mirroring the
        axis merge performed in `call`.
        """
        # Merge batch and ROI axes; the merged size is unknown if either
        # of them is unknown.
        if input_shape[0] is None or input_shape[1] is None:
            intermediate_shape = (None, *input_shape[2:])
        else:
            intermediate_shape = (
                input_shape[0] * input_shape[1],
                *input_shape[2:],
            )
        for idx, num_filters in enumerate(self.stackwise_num_conv_filters):
            # idx * 3 is the Conv2D, idx * 3 + 1 the BatchNormalization;
            # the Activation at idx * 3 + 2 is not built explicitly here.
            self.layers[idx * 3].build(intermediate_shape)
            intermediate_shape = tuple(intermediate_shape[:-1]) + (num_filters,)
            self.layers[idx * 3 + 1].build(intermediate_shape)
        self.deconv.build(intermediate_shape)
        # The transposed convolution doubles height and width.
        intermediate_shape = tuple(intermediate_shape[:-3]) + (
            intermediate_shape[-3] * 2,
            intermediate_shape[-2] * 2,
            self.num_deconv_filters,
        )
        self.segmentation_mask_output.build(intermediate_shape)
        self.built = True

    def get_config(self):
        """Returns the layer config including the constructor arguments."""
        config = super().get_config()
        config["num_classes"] = self.num_classes
        # Return a plain list copy so callers cannot mutate layer state
        # through the returned config.
        config["stackwise_num_conv_filters"] = list(
            self.stackwise_num_conv_filters
        )
        config["num_deconv_filters"] = self.num_deconv_filters
        return config
56 changes: 56 additions & 0 deletions keras_cv/src/models/object_detection/mask_rcnn/mask_head_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from absl.testing import parameterized

from keras_cv.src.backend import ops
from keras_cv.src.backend.config import keras_3
from keras_cv.src.models.object_detection.mask_rcnn import MaskHead
from keras_cv.src.tests.test_case import TestCase


class RCNNHeadTest(TestCase):
    """Shape checks for the Mask R-CNN `MaskHead` layer."""

    @parameterized.parameters(
        (2, 256, 20, 7, 256),
        (1, 512, 80, 14, 512),
    )
    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
    def test_mask_head_output_shapes(
        self,
        batch_size,
        num_rois,
        num_classes,
        roi_align_target_size,
        num_filters,
    ):
        """The head doubles the ROI size and emits `num_classes + 1` maps."""
        input_shape = (
            batch_size,
            num_rois,
            roi_align_target_size,
            roi_align_target_size,
            num_filters,
        )
        # The transposed convolution in the head upsamples by a factor of 2.
        upsampled = roi_align_target_size * 2
        expected_shape = (
            batch_size,
            num_rois,
            upsampled,
            upsampled,
            num_classes + 1,
        )

        head = MaskHead(num_classes)
        predictions = head(ops.ones(shape=input_shape))

        self.assertEqual(expected_shape, predictions.shape)
Loading

0 comments on commit 4dfa66f

Please sign in to comment.