Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/launch and param #5

Merged
merged 10 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Running detic instance segmentation with ROS 2 topic.
## How to use.

In order to build this package, just run.
This package supports ROS 2 Humble.

```
rosdep install -iry --from-paths .
Expand All @@ -29,10 +30,11 @@ You can see detection results in `/detic_result/image` topic with sensor_msgs/ms
- [x] Visualize segmentation result.
- [x] Publish object class.
- [x] Publish object score.
- [ ] Add launch file.
- [ ] Add config file for setting detection width / detic model type / vocaburary etc...
- [ ] Publish object mask.
- [ ] Inference with GPU.
- [x] Add launch file.
- [ ] Add parameter for setting detection width / detic model type / vocabulary etc...
- [x] Publish object mask.
- [x] Inference with GPU.
- [ ] Add test case.

## Limitation
Custom vocabulary is not supported because the ONNX model used in this package does not support it.
84 changes: 52 additions & 32 deletions detic_onnx_ros2/detic_onnx_ros2/detic_onnx_ros2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, List, Dict
from typing import Any, List, Dict, Tuple
import requests
import onnxruntime
import PIL.Image
Expand All @@ -11,29 +11,38 @@
from rclpy.node import Node
from sensor_msgs.msg import Image

from detic_onnx_ros2_msg.msg import SegmentationInfo
from detic_onnx_ros2_msg.msg import (
SegmentationInfo,
Segmentation,
Polygon,
PointOnImage,
)

from cv_bridge import CvBridge
from detic_onnx_ros2.imagenet_21k import IN21K_CATEGORIES
from detic_onnx_ros2.lvis import LVIS_CATEGORIES as LVIS_V1_CATEGORIES
from ament_index_python import get_package_share_directory
from detic_onnx_ros2.color import random_color, color_brightness
import copy
import time


class DeticNode(Node):
def __init__(self):
super().__init__("detic_node")
self.declare_parameter("detection_width", 800)
self.detection_width: int = self.get_parameter("detection_width").value
self.weight_and_model = self.download_onnx(
"Detic_C2_SwinB_896_4x_IN-21K+COCO_lvis_op16.onnx"
)
self.session = onnxruntime.InferenceSession(
self.weight_and_model,
providers=["CPUExecutionProvider"], # "CUDAExecutionProvider"],
)
self.publisher = self.create_publisher(Image, "detic_result/image", 10)
self.image_publisher = self.create_publisher(
Image, self.get_name() + "/detic_result/image", 10
)
self.segmentation_publisher = self.create_publisher(
SegmentationInfo, "segmentationinfo", 10
SegmentationInfo, self.get_name() + "/detic_result/segmentation_info", 10
)
self.subscription = self.create_subscription(
Image,
Expand All @@ -42,7 +51,6 @@ def __init__(self):
10,
)
self.bridge = CvBridge()
self.segmentationinfo = SegmentationInfo()

def download_onnx(
self,
Expand Down Expand Up @@ -76,8 +84,8 @@ def get_in21k_meta_v1(self) -> Dict[str, List[str]]:

def draw_predictions(
self, image: np.ndarray, detection_results: Any, vocabulary: str
) -> np.ndarray:

) -> Tuple[np.ndarray, List[Segmentation]]:
segmentations: List[Segmentation] = []
width = image.shape[1]
height = image.shape[0]

Expand All @@ -91,8 +99,10 @@ def draw_predictions(
if vocabulary == "lvis"
else self.get_in21k_meta_v1()
)["thing_classes"]
labels = [class_names[i] for i in classes]
labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
object_labels = [class_names[i] for i in classes]
labels = [
"{} {:.0f}%".format(l, s * 100) for l, s in zip(object_labels, scores)
]

num_instances = len(boxes)

Expand All @@ -111,6 +121,7 @@ def draw_predictions(
default_font_size = int(max(np.sqrt(height * width) // 90, 10))

for i in range(num_instances):
segmentation: Segmentation = Segmentation()
color = assigned_colors[i]
color = (int(color[0]), int(color[1]), int(color[2]))
image_b = image.copy()
Expand All @@ -124,12 +135,26 @@ def draw_predictions(
color=color,
thickness=default_font_size // 4,
)
segmentation.object_class = object_labels[i]
segmentation.score = float(scores[i])
segmentation.bounding_box.xmin = int(min(x0, x1))
segmentation.bounding_box.xmax = int(max(x0, x1))
segmentation.bounding_box.ymin = int(min(y0, y1))
segmentation.bounding_box.ymax = int(max(y0, y1))

# draw segment
polygons = self.mask_to_polygons(masks[i])
for points in polygons:
polygon = Polygon()
points = np.array(points).reshape((1, -1, 2)).astype(np.int32)
for i in range(points[0].shape[0]):
point_on_image = PointOnImage()
point_on_image.x = int(points[0][i][0])
point_on_image.y = int(points[0][i][1])
polygon.points.append(point_on_image)
cv2.fillPoly(image_b, pts=[points], color=color)
segmentation.polygons.append(polygon)
segmentations.append(segmentation)

image = cv2.addWeighted(image, 0.5, image_b, 0.5, 0)

Expand Down Expand Up @@ -176,7 +201,7 @@ def draw_predictions(
lineType=cv2.LINE_AA,
)

return image
return image, segmentations

def mask_to_polygons(self, mask: np.ndarray) -> List[Any]:
# cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
Expand All @@ -199,11 +224,11 @@ def mask_to_polygons(self, mask: np.ndarray) -> List[Any]:
# would be to first +0.5 and then dilate the returned polygon by 0.5.
return [x + 0.5 for x in res if len(x) >= 6]

def preprocess(self, image: np.ndarray, detection_width: int = 800) -> np.ndarray:
def preprocess(self, image: np.ndarray) -> np.ndarray:
height, width, _ = image.shape
image = image[:, :, ::-1] # BGR -> RGB
size = detection_width
max_size = detection_width
size = self.detection_width
max_size = self.detection_width
scale = size / min(height, width)
if height < width:
oh, ow = size, scale * width
Expand Down Expand Up @@ -237,13 +262,20 @@ def image_callback(self, msg):
image = self.preprocess(image=input_image)
input_height = image.shape[2]
input_width = image.shape[3]
inference_start_time = time.perf_counter()
boxes, scores, classes, masks = self.session.run(
None,
{
"img": image,
"im_hw": np.array([input_height, input_width]).astype(np.int64),
},
)
inference_end_time = time.perf_counter()
self.get_logger().info(
"Inference takes "
+ str(inference_end_time - inference_start_time)
+ " [sec]"
)
draw_mask = masks
masks = masks.astype(np.uint8)
draw_classes = classes
Expand All @@ -258,37 +290,25 @@ def image_callback(self, msg):
boxes = boxes[sorted_idxs]
labels = [labels[k] for k in sorted_idxs]
masks = [masks[idx] for idx in sorted_idxs]
# print(f"mask data type : {type(masks)}")
# print(f"mask data shape : {masks[0].shape}")
# print(f"mask data : {masks}")
scores = scores.astype(np.float32)
segMsg = self.bridge.cv2_to_imgmsg(masks[0], "8UC1")
# segMsg = []
# for i in masks:
# segMsg.append(self.bridge.cv2_to_imgmsg(i, 'mono8'))

self.segmentationinfo.header.stamp = self.get_clock().now().to_msg()
self.segmentationinfo.detected_classes = labels
# self.segmentationinfo.scores = scores
self.segmentationinfo.segmentation = segMsg

self.segmentation_publisher.publish(self.segmentationinfo)

detection_results = {
"boxes": draw_boxes,
"scores": draw_scores,
"classes": draw_classes,
"masks": draw_mask,
}
visualization = self.draw_predictions(
visualization, segmentations = self.draw_predictions(
cv2.cvtColor(
cv2.resize(input_image, (input_width, input_height)), cv2.COLOR_BGR2RGB
),
detection_results,
"lvis",
)
imgMsg = self.bridge.cv2_to_imgmsg(visualization, "bgr8")
self.publisher.publish(imgMsg)
segmentation_info = SegmentationInfo()
segmentation_info.header = msg.header
segmentation_info.segmentations = segmentations
self.segmentation_publisher.publish(segmentation_info)
self.image_publisher.publish(self.bridge.cv2_to_imgmsg(visualization, "bgr8"))


def main(args=None):
Expand Down
8 changes: 8 additions & 0 deletions detic_onnx_ros2/launch/detic_onnx_ros2.launch.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<launch>
<!-- Image topic the node subscribes to; remapped onto the node's "image_raw" input. -->
<arg name="input_topic" default="image_raw"/>
<!-- Width (pixels) the input image is scaled to before inference; read by the node's "detection_width" parameter. -->
<arg name="detection_width" default="800"/>
<node name="detic_onnx_ros2_node" pkg="detic_onnx_ros2" exec="detic_onnx_ros2_node">
<remap from="image_raw" to="$(var input_topic)"/>
<param name="detection_width" value="$(var detection_width)"/>
</node>
</launch>
2 changes: 2 additions & 0 deletions detic_onnx_ros2/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
<maintainer email="[email protected]">ubuntu</maintainer>
<license>Apache 2.0</license>

<depend>detic_onnx_ros2_msg</depend>
<depend>python3-onnxruntime-gpu-pip</depend>
<depend>python3-pil</depend>
<depend>python3-requests</depend>
<depend>launch_xml</depend>

<test_depend>ament_copyright</test_depend>
<test_depend>ament_pep257</test_depend>
Expand Down
3 changes: 3 additions & 0 deletions detic_onnx_ros2/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from setuptools import setup
import os
from glob import glob

package_name = "detic_onnx_ros2"

Expand All @@ -9,6 +11,7 @@
data_files=[
("share/ament_index/resource_index/packages", ["resource/" + package_name]),
("share/" + package_name, ["package.xml"]),
(os.path.join("share", package_name), glob("./launch/*.launch.xml")),
],
install_requires=["setuptools"],
zip_safe=True,
Expand Down
7 changes: 4 additions & 3 deletions detic_onnx_ros2_msg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)
find_package(sensor_msgs REQUIRED)
# uncomment the following section in order to fill in
# further dependencies manually.
# find_package(<dependency> REQUIRED)

# Generate language bindings (C++/Python) for this package's message definitions.
# Order lists leaf messages first; Segmentation/SegmentationInfo compose the others.
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/PointOnImage.msg"
"msg/Polygon.msg"
"msg/BoundingBox.msg"
"msg/Segmentation.msg"
"msg/SegmentationInfo.msg"
DEPENDENCIES std_msgs sensor_msgs
)
Expand Down
2 changes: 0 additions & 2 deletions detic_onnx_ros2_msg/msg/BoundingBox.msg
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,3 @@ int64 xmin
int64 ymin
int64 xmax
int64 ymax
int16 id
string Class
2 changes: 0 additions & 2 deletions detic_onnx_ros2_msg/msg/BoundingBoxes.msg

This file was deleted.

Empty file.
2 changes: 2 additions & 0 deletions detic_onnx_ros2_msg/msg/PointOnImage.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# A single pixel coordinate on an image.
# NOTE(review): presumably origin is the top-left corner (OpenCV convention) — confirm against publisher.
uint32 x
uint32 y
1 change: 1 addition & 0 deletions detic_onnx_ros2_msg/msg/Polygon.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
detic_onnx_ros2_msg/PointOnImage[] points
4 changes: 4 additions & 0 deletions detic_onnx_ros2_msg/msg/Segmentation.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# One detected object instance produced by the Detic ONNX node.
# Detected class label (includes the confidence percentage as formatted by the node).
string object_class
# Detection confidence score in [0, 1].
float32 score
# Axis-aligned bounding box of the instance in image pixels.
detic_onnx_ros2_msg/BoundingBox bounding_box
# Mask contours; a single instance mask may decompose into several polygons.
detic_onnx_ros2_msg/Polygon[] polygons
4 changes: 1 addition & 3 deletions detic_onnx_ros2_msg/msg/SegmentationInfo.msg
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
std_msgs/Header header
string[] detected_classes
float32[] scores
sensor_msgs/Image segmentation
detic_onnx_ros2_msg/Segmentation[] segmentations
Loading