diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py index 56420ed6e..85b741c35 100644 --- a/supervision/detection/utils.py +++ b/supervision/detection/utils.py @@ -292,7 +292,7 @@ def box_non_max_merge( Dict[int, List[int]]: Mapping from prediction indices to keep to a list of prediction indices to be merged. """ - keep_to_merge_list = {} + keep_to_merge_list: Dict[int, List[int]] = {} scores = predictions[:, 4] order = scores.argsort() @@ -307,17 +307,11 @@ def box_non_max_merge( break ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4]) + ious = ious.flatten() - below_threshold = (ious < iou_threshold).astype(np.uint8) - matched_box_indices = np.flip(order[np.where(below_threshold == 0)[0]]) - unmatched_indices = order[np.where(below_threshold == 1)[0]] - - order = unmatched_indices[scores[unmatched_indices].argsort()] - - keep_to_merge_list[idx.tolist()] = [] - - for matched_box_ind in matched_box_indices.tolist(): - keep_to_merge_list[idx.tolist()].append(matched_box_ind) + above_threshold = ious >= iou_threshold + keep_to_merge_list[idx] = np.flip(order[above_threshold]).tolist() + order = order[~above_threshold] return keep_to_merge_list