-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Discover detectors function added * changed dir from ../Detectors to Detectors in ObjectDetection Node, added empty DetectorBase.py * Added DetectorBase Abstract Class for plugin architecture * added configurable params in .yaml file, added show_fps and is_cuda to Detector constructors except EfficientDet, moved BaseDetector out of Detector directory * Updated DetectorBase abstract class and Yolov5 Plugin * Updated adhesive model color * Used getattr() to get detector instance * added launch file and params.yaml (#14) * added launch file and params.yaml * migrated from garden --> fortress * Create detector_plugin.md * Fixed Ignition GUI in world file * Bounding box around objects for image pub (#19) * Cleared prediction list in get_predictions() * Added bounding box around detected object * Update ObjectDetection.py Removed printing of model path * added output=screen in launch file --------- Co-authored-by: topguns837 <[email protected]> Co-authored-by: Arjun K Haridas <[email protected]> Co-authored-by: inferno2211 <[email protected]> Co-authored-by: Abir Thakur <[email protected]> Co-authored-by: topguns837 <[email protected]>
- Loading branch information
1 parent
58c947f
commit 958c2cb
Showing
12 changed files
with
306 additions
and
264 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Detector Plugin Architecture | ||
|
||
The `object_detection` package follows a plugin-based architecture for allowing the use of different object detection models. These can be loaded at launch time by setting the [`detector_type`](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/2c8152d6a5ae5b5f6e3541648ae97d9ba79ac6a9/object_detection/config/params.yaml#L7P) | ||
param in the `config/params.yaml` file of the package. Currently the package supports the following detectors out of the box: | ||
* YOLOv5 | ||
* YOLOv8 | ||
* RetinaNET | ||
* EdgeDET | ||
|
||
The package provides a [`DetectorBase`](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/detector_plugin_architecture/object_detection/object_detection/DetectorBase.py) class which is an abstract class. | ||
It uses the python's in-built `abc.ABC` to define the abstract class and `abc.abstractmethod` decorator to define the blueprint for different class methods that the plugin should implement. | ||
All the detector plugin classes are stored in the [Detectors](https://github.com/atom-robotics-lab/ros-perception-pipeline/tree/detector_plugin_architecture/object_detection/object_detection/Detectors) directory of the | ||
`object_detection` package. | ||
|
||
## Creating Your own Detector Plugin | ||
To create your own detector plugin, follow the steps below: | ||
* Create a file for your Detector class inside the [Detectors](https://github.com/atom-robotics-lab/ros-perception-pipeline/tree/detector_plugin_architecture/object_detection/object_detection/Detectors) directory. | ||
* The file should import the `DetectorBase` class from the [`DetectorBase.py`](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/detector_plugin_architecture/object_detection/object_detection/DetectorBase.py) module. You can create a class for your Detector in this file which should inherit the `DetectorBase` abstract class. | ||
|
||
> **Note:** The name of the file and class should be the same. This is required in order to allow the [`object_detection node`](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/071c63aa4bc71913d4bf5a4c7f9b4fd03b136338/object_detection/object_detection/ObjectDetection.py#L79) to load the module and its class using the value of the `detector_type` in the param. | ||
* Inside the class constructor, make sure to call the constructor of the `DetectorBase` class using `super().__init__()`. This initializes an empty `predictions` list that would be used to store the predictions later. (explained below) | ||
|
||
* After this, the Detector plugin class needs to implement the abstract methods listed below. These are defined in the `DetectorBase` class and provide a signature for the function's implementations. These abstract methods act as a standard API between the Detector plugins and the ObjectDetection node. The plugins only need to match the function signature (parameter and return types) to allow ObjectDetection node to use them. | ||
* [`build_model()`](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/071c63aa4bc71913d4bf5a4c7f9b4fd03b136338/object_detection/object_detection/DetectorBase.py#L21): It takes 2 strings as parameters: `model_dir_path` and `weight_file_name`. `model_dir_path` is the path which contains the model file and class file. The `weight_file_name` is the name of the weights file (like `.onxx` in case of Yolov5 models). This function should return no parameters and is used by the ObjectDetection node [here](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/071c63aa4bc71913d4bf5a4c7f9b4fd03b136338/object_detection/object_detection/ObjectDetection.py#L83). While creating the plugin, you need not worry about the parameters as they are provided by the node through the ROS 2 params. You just need to use their values inside the functions according to your Detector's requirements. | ||
|
||
* [`load_classes()`](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/071c63aa4bc71913d4bf5a4c7f9b4fd03b136338/object_detection/object_detection/DetectorBase.py#L25): This is similar to the `build_model()` function. It should load the classes file as per the requirement using the provided `model_dir_path` parameter. | ||
|
||
* [`get_predictions()`](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/071c63aa4bc71913d4bf5a4c7f9b4fd03b136338/object_detection/object_detection/DetectorBase.py#L29): This function is [used by the ObjectDetection node](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/071c63aa4bc71913d4bf5a4c7f9b4fd03b136338/object_detection/object_detection/ObjectDetection.py#L92) in the subscriber callback to get the predictions for each frame passed. This function should take an opencv image (which is essentially a numpy array) as a parameter and return a list of dictionaries that contain the predictions. This function can implement any kind of checks, formatting, preprocessing, etc. on the frame (see [this](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/071c63aa4bc71913d4bf5a4c7f9b4fd03b136338/object_detection/object_detection/Detectors/YOLOv5.py#L131) for example). It only strictly needs to follow the signature described by the abstract method definition in `DetectorBase`. To create the predictions list, the function should call the [`create_predictions_list()`](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/071c63aa4bc71913d4bf5a4c7f9b4fd03b136338/object_detection/object_detection/DetectorBase.py#L10) function from the `DetectorBase` class like [this](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/071c63aa4bc71913d4bf5a4c7f9b4fd03b136338/object_detection/object_detection/Detectors/YOLOv5.py#L144). Using the `create_predictions_list()` function is necessary as it arranges the prediction data in a standard format that the `ObjectDetection` node expects. | ||
|
||
> **Note:** For reference you can through the [YOLOv5 Plugin class](https://github.com/atom-robotics-lab/ros-perception-pipeline/blob/detector_plugin_architecture/object_detection/object_detection/Detectors/YOLOv5.py#L131) and how it implements all the abstract methods. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (c) 2018 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import sys | ||
|
||
from ament_index_python.packages import get_package_share_directory | ||
|
||
from launch import LaunchDescription | ||
from launch.actions import IncludeLaunchDescription, DeclareLaunchArgument | ||
from launch.substitutions import LaunchConfiguration | ||
from launch.launch_description_sources import PythonLaunchDescriptionSource | ||
from launch_ros.actions import Node | ||
|
||
|
||
def generate_launch_description(): | ||
pkg_object_detection = get_package_share_directory("object_detection") | ||
|
||
params = os.path.join( | ||
pkg_object_detection, | ||
'config', | ||
'params.yaml' | ||
) | ||
|
||
node=Node( | ||
package = 'object_detection', | ||
name = 'object_detection', | ||
executable = 'ObjectDetection', | ||
parameters = [params], | ||
output="screen" | ||
) | ||
|
||
|
||
return LaunchDescription([node]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from abc import ABC, abstractmethod | ||
import numpy as np | ||
|
||
|
||
class DetectorBase(ABC): | ||
|
||
def __init__(self) -> None: | ||
self.predictions = [] | ||
|
||
def create_predictions_list(self, class_ids, confidences, boxes): | ||
for i in range(len(class_ids)): | ||
obj_dict = { | ||
"class_id": class_ids[i], | ||
"confidence": confidences[i], | ||
"box": boxes[i] | ||
} | ||
|
||
self.predictions.append(obj_dict) | ||
|
||
@abstractmethod | ||
def build_model(self, model_dir_path: str, weight_file_name: str) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def load_classes(self, model_dir_path: str) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def get_predictions(self, cv_image: np.ndarray) -> list[dict]: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.