Skip to content

Commit

Permalink
Simplify box_non_max_merge
Browse files Browse the repository at this point in the history
  • Loading branch information
Linas Kondrackis committed May 15, 2024
1 parent 0e2eec0 commit 559ef90
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions supervision/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def box_non_max_merge(
Dict[int, List[int]]: Mapping from prediction indices
to keep to a list of prediction indices to be merged.
"""
keep_to_merge_list = {}
keep_to_merge_list: Dict[int, List[int]] = {}

scores = predictions[:, 4]
order = scores.argsort()
Expand All @@ -307,17 +307,11 @@ def box_non_max_merge(
break

ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4])
ious = ious.flatten()

below_threshold = (ious < iou_threshold).astype(np.uint8)
matched_box_indices = np.flip(order[np.where(below_threshold == 0)[0]])
unmatched_indices = order[np.where(below_threshold == 1)[0]]

order = unmatched_indices[scores[unmatched_indices].argsort()]

keep_to_merge_list[idx.tolist()] = []

for matched_box_ind in matched_box_indices.tolist():
keep_to_merge_list[idx.tolist()].append(matched_box_ind)
above_threshold = ious >= iou_threshold
keep_to_merge_list[idx] = np.flip(order[above_threshold]).tolist()
order = order[~above_threshold]

return keep_to_merge_list

Expand Down

0 comments on commit 559ef90

Please sign in to comment.