diff --git a/object_detection/object_detection/ObjectDetection.py b/object_detection/object_detection/ObjectDetection.py index 5ac3b48..7221986 100644 --- a/object_detection/object_detection/ObjectDetection.py +++ b/object_detection/object_detection/ObjectDetection.py @@ -65,12 +65,8 @@ def __init__(self): self.confidence_threshold = self.get_parameter('model_params.confidence_threshold').value self.show_fps = self.get_parameter('model_params.show_fps').value - # raise an exception if specified detector was not found - if self.detector_type not in self.available_detectors: - raise ModuleNotFoundError(self.detector_type + " Detector specified in config was not found. " + - "Check the Detectors dir for available detectors.") - else: - self.load_detector() + # Load the detector + self.load_detector() self.img_pub = self.create_publisher(Image, self.output_img_topic, 10) self.bb_pub = None @@ -95,15 +91,23 @@ def discover_detectors(self): self.available_detectors.remove('__init__') def load_detector(self): - detector_mod = importlib.import_module(".Detectors." + self.detector_type, - "object_detection") - detector_class = getattr(detector_mod, self.detector_type) - self.detector = detector_class() + for detector_name in self.available_detectors: + if self.detector_type.lower() == detector_name.lower(): + + detector_mod = importlib.import_module(".Detectors." + detector_name, + "object_detection") + detector_class = getattr(detector_mod, detector_name) + self.detector = detector_class() - self.detector.build_model(self.model_dir_path, self.weight_file_name) - self.detector.load_classes(self.model_dir_path) + self.detector.build_model(self.model_dir_path, self.weight_file_name) + self.detector.load_classes(self.model_dir_path) - print("Your detector: {} has been loaded !".format(self.detector_type)) + print("Your detector: {} has been loaded !".format(detector_name)) + return + + raise ModuleNotFoundError(self.detector_type + " Detector specified in config was not found. " + + "Check the Detectors dir for available detectors.") + def detection_cb(self, img_msg): cv_image = self.bridge.imgmsg_to_cv2(img_msg, "bgr8")