Skip to content

Commit

Permalink
Fixed Issue #55 (Implemented Confidence Threshold in ObjectDetection.…
Browse files Browse the repository at this point in the history
…py) (#61)

* Fixed Issue #55

* minor fix in RetinaNet.py constructor

* Removed Confidence Threshold Parameter from YOLOv8.py

* fixed linting issues in yolov8.py rev0

* fixed linting issues in yolov8.py rev1

* fixed linting issues rev2

* fixed linting issues rev3

* fixed linting issues rev4
  • Loading branch information
sudo-deep authored Feb 14, 2024
1 parent eaadc67 commit 6cf77a1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
2 changes: 1 addition & 1 deletion object_detection/object_detection/Detectors/RetinaNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class RetinaNet(DetectorBase):

def __init(self):
def __init__(self):
super.__init__()

def build_model(self, model_dir_path, weight_file_name):
Expand Down
6 changes: 3 additions & 3 deletions object_detection/object_detection/Detectors/YOLOv8.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

class YOLOv8(DetectorBase):

def __init__(self, conf_threshold=0.7):
def __init__(self):

super().__init__()
self.conf_threshold = conf_threshold

def build_model(self, model_dir_path, weight_file_name):
try:
Expand Down Expand Up @@ -53,7 +53,7 @@ def get_predictions(self, cv_image):
boxes = []

# Perform object detection on image
result = self.model.predict(self.frame, conf=self.conf_threshold, verbose=False)
result = self.model.predict(self.frame, verbose=False) # Perform object detection on image
row = result[0].boxes.cpu()

for box in row:
Expand Down
22 changes: 13 additions & 9 deletions object_detection/object_detection/ObjectDetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,23 @@ def detection_cb(self, img_msg):
print("Image input from topic: {} is empty".format(self.input_img_topic))
else:
for prediction in predictions:
x1, y1, x2, y2 = map(int, prediction['box'])
confidence = prediction['confidence']

# Draw the bounding box
cv_image = cv2.rectangle(cv_image, (x1, y1), (x2, y2), (0, 255, 0), 1)
# Check if the confidence is above the threshold
if confidence >= self.confidence_threshold:
x1, y1, x2, y2 = map(int, prediction['box'])

# Show names of classes on the output image
class_id = int(prediction['class_id'])
class_name = self.detector.class_list[class_id]
label = f"{class_name}: {prediction['confidence']:.2f}"
# Draw the bounding box
cv_image = cv2.rectangle(cv_image, (x1, y1), (x2, y2), (0, 255, 0), 1)

cv_image = cv2.putText(cv_image, label, (x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
# Show names of classes on the output image
class_id = int(prediction['class_id'])
class_name = self.detector.class_list[class_id]
label = f"{class_name} : {confidence:.2f}"

cv_image = cv2.putText(cv_image, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

# Publish the modified image
output = self.bridge.cv2_to_imgmsg(cv_image, "bgr8")
self.img_pub.publish(output)

Expand Down

0 comments on commit 6cf77a1

Please sign in to comment.