Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Non-Maximum Merging (NMM) to Detections #500

Merged
merged 29 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c78ae33
feat: 🚀 Added Non-Maximum Merging to Detections
Oct 13, 2023
57b12e6
Added __setitem__ to Detections and refactored the object prediction …
Oct 18, 2023
9f22273
Added standard full image inference after sliced inference to increas…
Oct 18, 2023
6f47046
Refactored merging of Detection attributes to better work with np.nda…
Oct 18, 2023
5f0dcc2
Merge branch 'develop' into add_nmm_to_detections to resolve conflicts
Apr 9, 2024
166a8da
Implement Feedback
Apr 11, 2024
b159873
Merge remote-tracking branch 'upstream/develop' into add_nmm_to_detec…
May 6, 2024
d7e52be
NMM: Add None-checks, fix area normalization, style
May 6, 2024
bee3252
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 6, 2024
97c4071
NMM: Move detections merge into Detections class.
May 6, 2024
204669b
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 6, 2024
2eb0c7c
Merge remote-tracking branch 'upstream/develop' into add_nmm_to_detec…
LinasKo May 14, 2024
c3b77d0
Rename, remove functions, unit-test & change `merge_object_detection_…
May 14, 2024
8014e88
Test box_non_max_merge
May 14, 2024
26bafec
Test box_non_max_merge, rename threshold,to __init__
May 15, 2024
d2d50fb
renamed bbox -> xyxy
May 15, 2024
2d740bd
fix: merge_object_detection_pair
May 15, 2024
145b5fe
Rename to batch_box_non_max_merge to box_non_max_merge_batch
May 15, 2024
6c40935
box_non_max_merge: use our functions to compute iou
May 15, 2024
53f345e
Minor renaming
May 15, 2024
0e2eec0
Revert np.bool comparisons with `is`
May 15, 2024
559ef90
Simplify box_non_max_merge
May 15, 2024
f8f3647
Removed suprplus NMM code for 20% speedup
May 15, 2024
9024396
Add npt.NDarray[x] types, remove resolution_wh default val
May 17, 2024
6fbca83
Address review comments, simplify merge
May 23, 2024
db1b473
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 23, 2024
0721bc2
Remove _set_at_index
May 23, 2024
530e1d0
Address comments
May 27, 2024
2ee9e08
Renamed to group_overlapping_boxes
May 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions supervision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from supervision.detection.tools.smoother import DetectionsSmoother
from supervision.detection.utils import (
box_iou_batch,
box_non_max_merge,
LinasKo marked this conversation as resolved.
Show resolved Hide resolved
box_non_max_suppression,
calculate_masks_centroids,
clip_boxes,
Expand Down
202 changes: 202 additions & 0 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
from supervision.detection.utils import (
box_iou_batch,
box_non_max_merge,
box_non_max_suppression,
calculate_masks_centroids,
extract_ultralytics_masks,
Expand Down Expand Up @@ -1150,3 +1152,203 @@ def with_nms(
)

return self[indices]

def with_nmm(
self, threshold: float = 0.5, class_agnostic: bool = False
) -> Detections:
"""
Perform non-maximum merging on the current set of object detections.

Args:
threshold (float, optional): The intersection-over-union threshold
to use for non-maximum merging. Defaults to 0.5.
class_agnostic (bool, optional): Whether to perform class-agnostic
non-maximum merging. If True, the class_id of each detection
will be ignored. Defaults to False.

Returns:
Detections: A new Detections object containing the subset of detections
after non-maximum merging.

Raises:
AssertionError: If `confidence` is None or `class_id` is None and
class_agnostic is False.
"""
if len(self) == 0:
return self

assert (
self.confidence is not None
), "Detections confidence must be given for NMM to be executed."

if class_agnostic:
predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1)))
else:
assert self.class_id is not None, (
"Detections class_id must be given for NMM to be executed. If you"
" intended to perform class agnostic NMM set class_agnostic=True."
)
predictions = np.hstack(
(
self.xyxy,
self.confidence.reshape(-1, 1),
self.class_id.reshape(-1, 1),
)
)

merge_groups = box_non_max_merge(
predictions=predictions, iou_threshold=threshold
)

result = []
for merge_group in merge_groups:
unmerged_detections = [self[i] for i in merge_group]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we don't need that list comprehension, just use detections[indexes].

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My explanation was wrong.

We're doing this not to copy the result (that's in another case), but to create a list of single-object detections. [Detections, Detections, Detections, ...].

I believe this is the most concise way.

merged_detections = _merge_inner_detections_objects(
unmerged_detections, threshold
)
result.append(merged_detections)

return Detections.merge(result)


def _merge_inner_detection_object_pair(
LinasKo marked this conversation as resolved.
Show resolved Hide resolved
detections_1: Detections, detections_2: Detections
) -> Detections:
"""
Merges two Detections object into a single Detections object.
Assumes each Detections contains exactly one object.

A `winning` detection is determined based on the confidence score of the two
input detections. This winning detection is then used to specify which
`class_id`, `tracker_id`, and `data` to include in the merged Detections object.

The resulting `confidence` of the merged object is calculated by the weighted
contribution of ea detection to the merged object.
The bounding boxes and masks of the two input detections are merged into a
single bounding box and mask, respectively.

Args:
detections_1 (Detections):
The first Detections object
detections_2 (Detections):
The second Detections object

Returns:
Detections: A new Detections object, with merged attributes.

Raises:
ValueError: If the input Detections objects do not have exactly 1 detected
object.

Example:
```python
import cv2
import supervision as sv
from inference import get_model

image = cv2.imread(<SOURCE_IMAGE_PATH>)
model = get_model(model_id="yolov8s-640")

result = model.infer(image)[0]
detections = sv.Detections.from_inference(result)

merged_detections = merge_object_detection_pair(
detections[0], detections[1])
```
"""
if len(detections_1) != 1 or len(detections_2) != 1:
raise ValueError("Both Detections should have exactly 1 detected object.")

_verify_fields_both_defined_or_none(detections_1, detections_2)

if detections_1.confidence is None and detections_2.confidence is None:
merged_confidence = None
else:
area_det1 = (detections_1.xyxy[0][2] - detections_1.xyxy[0][0]) * (
LinasKo marked this conversation as resolved.
Show resolved Hide resolved
detections_1.xyxy[0][3] - detections_1.xyxy[0][1]
)
area_det2 = (detections_2.xyxy[0][2] - detections_2.xyxy[0][0]) * (
detections_2.xyxy[0][3] - detections_2.xyxy[0][1]
)
merged_confidence = (
area_det1 * detections_1.confidence[0]
+ area_det2 * detections_2.confidence[0]
) / (area_det1 + area_det2)
merged_confidence = np.array([merged_confidence])

merged_x1, merged_y1 = np.minimum(
detections_1.xyxy[0][:2], detections_2.xyxy[0][:2]
)
merged_x2, merged_y2 = np.maximum(
detections_1.xyxy[0][2:], detections_2.xyxy[0][2:]
)
merged_xyxy = np.array([[merged_x1, merged_y1, merged_x2, merged_y2]])

if detections_1.mask is None and detections_2.mask is None:
merged_mask = None
else:
merged_mask = np.logical_or(detections_1.mask, detections_2.mask)

if detections_1.confidence is None and detections_2.confidence is None:
winning_det = detections_1
LinasKo marked this conversation as resolved.
Show resolved Hide resolved
elif detections_1.confidence[0] >= detections_2.confidence[0]:
winning_det = detections_1
else:
winning_det = detections_2

winning_class_id = winning_det.class_id
winning_tracker_id = winning_det.tracker_id
winning_data = winning_det.data
LinasKo marked this conversation as resolved.
Show resolved Hide resolved

return Detections(
xyxy=merged_xyxy,
mask=merged_mask,
confidence=merged_confidence,
class_id=winning_class_id,
tracker_id=winning_tracker_id,
data=winning_data,
)


def _merge_inner_detections_objects(
detections: List[Detections], threshold=0.5
) -> Detections:
"""
Given N detections each of length 1 (exactly one object inside), combine them into a
single detection object of length 1. The contained inner object will be the merged
result of all the input detections.

For example, this lets you merge N boxes into one big box, N masks into one mask,
etc.
"""
detections_1 = detections[0]
for detections_2 in detections[1:]:
box_iou = box_iou_batch(detections_1.xyxy, detections_2.xyxy)[0]
if box_iou < threshold:
break
detections_1 = _merge_inner_detection_object_pair(detections_1, detections_2)
return detections_1


def _verify_fields_both_defined_or_none(
LinasKo marked this conversation as resolved.
Show resolved Hide resolved
detections_1: Detections, detections_2: Detections
) -> None:
"""
Verify that for each optional field in the Detections, both instances either have
the field set to None or both have it set to non-None values.

`data` field is ignored.

Raises:
ValueError: If one field is None and the other is not, for any of the fields.
"""
attributes = ["mask", "confidence", "class_id", "tracker_id"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we try to get that list automatically?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried it, it's cumbersome, I'll add the code + tests in a separate PR and we can choose whether to keep it.

for attribute in attributes:
value_1 = getattr(detections_1, attribute)
value_2 = getattr(detections_2, attribute)

if (value_1 is None) != (value_2 is None):
raise ValueError(
f"Field '{attribute}' should be consistently None or not None in both "
"Detections."
)
Loading