The `object_detection` package follows a plugin-based architecture that allows different object detection models to be used. The model is selected at launch time by setting the `detector_type` parameter in the package's `config/params.yaml` file. The package currently supports the following detectors out of the box:
- YOLOv5
- YOLOv8
- RetinaNET
- EdgeDET
The package provides a `DetectorBase` class, which is an abstract class. It uses Python's built-in `abc.ABC` to define the abstract class and the `abc.abstractmethod` decorator to define the blueprint of the methods that every plugin must implement.
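A minimal sketch of what such an abstract base class might look like follows. It is a hypothetical reconstruction based on the method names and parameters described in this section, not the package's actual `DetectorBase.py`; the untyped `cv_image` argument stands in for the OpenCV (NumPy) frame. It also shows the practical effect of `abc.abstractmethod`: a plugin that leaves any abstract method unimplemented cannot be instantiated.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class DetectorBase(ABC):
    """Hypothetical sketch of the abstract plugin base class."""

    def __init__(self) -> None:
        # Empty list that later holds the formatted predictions of a frame.
        self.predictions: List[Dict[str, Any]] = []

    @abstractmethod
    def build_model(self, model_dir_path: str, weight_file_name: str) -> None:
        """Load the model weights found in model_dir_path."""

    @abstractmethod
    def load_classes(self, model_dir_path: str) -> None:
        """Load the class names found in model_dir_path."""

    @abstractmethod
    def get_predictions(self, cv_image) -> List[Dict[str, Any]]:
        """Run inference on one frame and return the predictions list."""


class IncompleteDetector(DetectorBase):
    # Implements only one of the three abstract methods.
    def build_model(self, model_dir_path: str, weight_file_name: str) -> None:
        pass


try:
    IncompleteDetector()
except TypeError as err:
    # abc refuses to instantiate a class with unimplemented abstract methods.
    print(err)
```

The exact wording of the `TypeError` message varies between Python versions, but the instantiation fails either way, which is what lets the node rely on every loaded plugin providing the full API.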
All detector plugin classes are stored in the `Detectors` directory of the `object_detection` package.
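The steps below require the file name, the class name, and the `detector_type` value to match. One plausible way a node could exploit that convention is sketched here with the standard `importlib` module. The `load_detector` function and its `package` argument are hypothetical; the package's actual loading code may differ.

```python
import importlib


def load_detector(detector_type: str, package: str = "Detectors"):
    """Hypothetical loader: import <package>.<detector_type> and instantiate
    the class of the same name, e.g. Detectors/YOLOv5.py -> YOLOv5()."""
    # The module path is built from the parameter value, so the file name
    # must equal detector_type (without the .py extension).
    module = importlib.import_module(f"{package}.{detector_type}")
    # The class is looked up with the same string, so its name must match too.
    detector_class = getattr(module, detector_type)
    return detector_class()
```

For example, with `detector_type` set to `YOLOv5`, `load_detector("YOLOv5")` would import `Detectors.YOLOv5` and return a `YOLOv5` instance, assuming `Detectors` is importable as a Python package. A mismatched class name surfaces as an `AttributeError` from `getattr`, and a mismatched file name as a `ModuleNotFoundError`.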
To create your own detector plugin, follow the steps below:
- Create a file for your detector class inside the `Detectors` directory.
- The file should import the `DetectorBase` class from the `DetectorBase.py` module. In this file, create a class for your detector that inherits from the `DetectorBase` abstract class. Note: the file and the class must have the same name. This is required so that the `object_detection` node can load the module and its class using the value of the `detector_type` parameter.
- Inside the class constructor, make sure to call the constructor of the `DetectorBase` class using `super().__init__()`. This initializes an empty `predictions` list that is later used to store the predictions (explained below).
- After this, the detector plugin class needs to implement the abstract methods listed below. They are defined in the `DetectorBase` class and specify the signatures that the implementations must follow. These abstract methods act as a standard API between the detector plugins and the ObjectDetection node: a plugin only needs to match each method's signature (parameter and return types) for the ObjectDetection node to be able to use it.
  - `build_model()`: It takes two strings as parameters, `model_dir_path` and `weight_file_name`. `model_dir_path` is the path of the directory that contains the model file and the classes file, and `weight_file_name` is the name of the weights file (such as the `.onnx` file in the case of YOLOv5 models). This function should not return anything and is called by the ObjectDetection node. While creating the plugin, you need not worry about supplying these parameters, as the node provides them through ROS 2 params. You only need to use their values inside the function according to your detector's requirements.
  - `load_classes()`: This is similar to the `build_model()` function. It should load the classes file as your detector requires, using the provided `model_dir_path` parameter.
  - `get_predictions()`: The ObjectDetection node calls this function in its subscriber callback to get the predictions for each incoming frame. It should take an OpenCV image (which is essentially a NumPy array) as a parameter and return a list of dictionaries containing the predictions. The function can apply any checks, formatting, or preprocessing to the frame, as long as it strictly follows the signature defined by the abstract method in `DetectorBase`. To build the predictions list, the function should call the `create_predictions_list()` method of the `DetectorBase` class. Using `create_predictions_list()` is necessary because it arranges the prediction data in the standard format that the `ObjectDetection` node expects.
- Note: For reference, you can go through the YOLOv5 plugin class to see how it implements all the abstract methods.
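Putting the steps together, a toy plugin might look like the sketch below. Everything in it is illustrative: `DummyDetector`, the `classes.txt` file name, the `[x, y, width, height]` box layout, and the dictionary keys produced by the stand-in `create_predictions_list()` are assumptions. The stand-in `DetectorBase` only mimics the interface described above so that the example runs on its own; a real plugin would import `DetectorBase` from `DetectorBase.py` and run an actual model.

```python
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class DetectorBase(ABC):
    """Minimal stand-in for the package's base class (assumed, for illustration)."""

    def __init__(self) -> None:
        self.predictions: List[Dict[str, Any]] = []

    def create_predictions_list(self, class_ids, confidences, boxes) -> None:
        # Hypothetical standard format: one dict per detection.
        self.predictions = [
            {"class_id": cid, "confidence": conf, "box": box}
            for cid, conf, box in zip(class_ids, confidences, boxes)
        ]

    @abstractmethod
    def build_model(self, model_dir_path: str, weight_file_name: str) -> None: ...

    @abstractmethod
    def load_classes(self, model_dir_path: str) -> None: ...

    @abstractmethod
    def get_predictions(self, cv_image) -> List[Dict[str, Any]]: ...


class DummyDetector(DetectorBase):
    """Would live in Detectors/DummyDetector.py (file and class share a name)."""

    def __init__(self) -> None:
        super().__init__()  # initializes self.predictions
        self.weights_path = None
        self.class_list: List[str] = []

    def build_model(self, model_dir_path: str, weight_file_name: str) -> None:
        # A real plugin would load e.g. an ONNX model here; this only records the path.
        self.weights_path = os.path.join(model_dir_path, weight_file_name)

    def load_classes(self, model_dir_path: str) -> None:
        # Hypothetical classes file: one class name per line.
        with open(os.path.join(model_dir_path, "classes.txt")) as f:
            self.class_list = [line.strip() for line in f if line.strip()]

    def get_predictions(self, cv_image) -> List[Dict[str, Any]]:
        # Basic input check before "inference".
        if cv_image is None:
            return []
        height, width = len(cv_image), len(cv_image[0])
        # Pretend the model found one object of class 0 covering the whole frame.
        self.create_predictions_list([0], [0.9], [[0, 0, width, height]])
        return self.predictions
```

This stand-in rebuilds the predictions list on every call so that detections from one frame do not carry over into the next; check the real `DetectorBase` for its actual behavior.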