Skip to content

Commit

Permalink
Made the detector_type param case-insensitive
Browse files Browse the repository at this point in the history
  • Loading branch information
topguns837 committed Feb 14, 2024
1 parent 640747f commit 06b7189
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions object_detection/object_detection/ObjectDetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,8 @@ def __init__(self):
self.confidence_threshold = self.get_parameter('model_params.confidence_threshold').value
self.show_fps = self.get_parameter('model_params.show_fps').value

# raise an exception if specified detector was not found
if self.detector_type not in self.available_detectors:
raise ModuleNotFoundError(self.detector_type + " Detector specified in config was not found. " +
"Check the Detectors dir for available detectors.")
else:
self.load_detector()
# Load the detector
self.load_detector()

self.img_pub = self.create_publisher(Image, self.output_img_topic, 10)
self.bb_pub = None
Expand All @@ -95,15 +91,23 @@ def discover_detectors(self):
self.available_detectors.remove('__init__')

def load_detector(self):
detector_mod = importlib.import_module(".Detectors." + self.detector_type,
"object_detection")
detector_class = getattr(detector_mod, self.detector_type)
self.detector = detector_class()
for detector_name in self.available_detectors:
if self.detector_type.lower() == detector_name.lower():

detector_mod = importlib.import_module(".Detectors." + detector_name,
"object_detection")
detector_class = getattr(detector_mod, detector_name)
self.detector = detector_class()

self.detector.build_model(self.model_dir_path, self.weight_file_name)
self.detector.load_classes(self.model_dir_path)
self.detector.build_model(self.model_dir_path, self.weight_file_name)
self.detector.load_classes(self.model_dir_path)

print("Your detector: {} has been loaded !".format(self.detector_type))
print("Your detector: {} has been loaded !".format(detector_name))
return

raise ModuleNotFoundError(self.detector_type + " Detector specified in config was not found. " +
"Check the Detectors dir for available detectors.")


def detection_cb(self, img_msg):
cv_image = self.bridge.imgmsg_to_cv2(img_msg, "bgr8")
Expand Down

0 comments on commit 06b7189

Please sign in to comment.