Skip to content
90 changes: 90 additions & 0 deletions supervision/detection/utils/iou_and_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,96 @@ def box_area(box):
return ious


def box_iou_batch_alt(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's rename it to box_iou_batch and remove old implementation of this function.

boxes_true: np.ndarray,
boxes_detection: np.ndarray,
overlap_metric: OverlapMetric = OverlapMetric.IOU,
) -> np.ndarray:
"""
Compute Intersection over Union (IoU) of two sets of bounding boxes -
`boxes_true` and `boxes_detection`. Both sets
of boxes are expected to be in `(x_min, y_min, x_max, y_max)` format.

Note:
Use `box_iou` when computing IoU between two individual boxes.
For comparing multiple boxes (arrays of boxes), use `box_iou_batch` for better
performance.

Args:
boxes_true (np.ndarray): 2D `np.ndarray` representing ground-truth boxes.
`shape = (N, 4)` where `N` is number of true objects.
boxes_detection (np.ndarray): 2D `np.ndarray` representing detection boxes.
`shape = (M, 4)` where `M` is number of detected objects.
overlap_metric (OverlapMetric): Metric used to compute the degree of overlap
between pairs of boxes (e.g., IoU, IoS).

Returns:
np.ndarray: Pairwise IoU of boxes from `boxes_true` and `boxes_detection`.
`shape = (N, M)` where `N` is number of true objects and
`M` is number of detected objects.

Examples:
```python
import numpy as np
import supervision as sv

boxes_true = np.array([
[100, 100, 200, 200],
[300, 300, 400, 400]
])
boxes_detection = np.array([
[150, 150, 250, 250],
[320, 320, 420, 420]
])

sv.box_iou_batch(boxes_true=boxes_true, boxes_detection=boxes_detection)
# array([
# [0.14285714, 0. ],
# [0. , 0.47058824]
# ])
```

"""

tx1, ty1, tx2, ty2 = boxes_true.T
dx1, dy1, dx2, dy2 = boxes_detection.T
N, M = boxes_true.shape[0], boxes_detection.shape[0]

top_left_x = np.empty((N, M), dtype=np.float32)
bottom_right_x = np.empty_like(top_left_x)
top_left_y = np.empty_like(top_left_x)
bottom_right_y = np.empty_like(top_left_x)

np.maximum(tx1[:, None], dx1[None, :], out=top_left_x)
np.minimum(tx2[:, None], dx2[None, :], out=bottom_right_x)
np.maximum(ty1[:, None], dy1[None, :], out=top_left_y)
np.minimum(ty2[:, None], dy2[None, :], out=bottom_right_y)

np.subtract(bottom_right_x, top_left_x, out=bottom_right_x) # W
np.subtract(bottom_right_y, top_left_y, out=bottom_right_y) # H
np.clip(bottom_right_x, 0.0, None, out=bottom_right_x)
np.clip(bottom_right_y, 0.0, None, out=bottom_right_y)

area_inter = bottom_right_x * bottom_right_y

area_true = (tx2 - tx1) * (ty2 - ty1)
area_detection = (dx2 - dx1) * (dy2 - dy1)

if overlap_metric == OverlapMetric.IOU:
denom = area_true[:, None] + area_detection[None, :] - area_inter
elif overlap_metric == OverlapMetric.IOS:
denom = np.minimum(area_true[:, None], area_detection[None, :])
else:
raise ValueError(
f"overlap_metric {overlap_metric} is not supported, "
"only 'IOU' and 'IOS' are supported"
)

out = np.zeros_like(area_inter, dtype=np.float32)
np.divide(area_inter, denom, out=out, where=denom > 0)
return out


def _jaccard(box_a: list[float], box_b: list[float], is_crowd: bool) -> float:
"""
Calculate the Jaccard index (intersection over union) between two bounding boxes.
Expand Down
38 changes: 38 additions & 0 deletions test/detection/utils/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import random

import numpy as np


def generate_boxes(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In supervision/test/test_utils.py, there are already two functions, mock_detections and mock_key_points. It seems we are duplicating some logic from mock_detections. Try to unify them.

If that’s not possible, at least move generate_boxes to supervision/test/test_utils.py and make its arguments and naming conventions consistent with the existing functions.

One quick improvement would be to replace the separate W and H arguments with a single resolution_wh argument, which we use throughout the codebase.

n: int,
W: int = 1920,
H: int = 1080,
min_size: int = 20,
max_size: int = 200,
seed: int | None = 1,
):
"""
Generate N valid bounding boxes of format [x_min, y_min, x_max, y_max].
Args:
n (int): Number of boexs to generate
W (int): Image width
H (int): Image height
min_size (int): Minimum box size (width/height)
max_size (int): Maximum box size (width/height)
seed (int | None): Random seed for reproducibility
Returns:
list[list[float]] | np.ndarray: List of boxes
"""
random.seed(seed)
boxes = []
for _ in range(n):
w = random.uniform(min_size, max_size)
h = random.uniform(min_size, max_size)
x1 = random.uniform(0, W - w)
y1 = random.uniform(0, H - h)
x2 = x1 + w
y2 = y1 + h
boxes.append([x1, y1, x2, y2])
return np.array(boxes, dtype=np.float32)
14 changes: 14 additions & 0 deletions test/detection/utils/test_iou_and_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

from supervision.detection.utils.iou_and_nms import (
_group_overlapping_boxes,
box_iou_batch,
box_iou_batch_alt,
box_non_max_suppression,
mask_non_max_merge,
mask_non_max_suppression,
)
from test.detection.utils.functions import generate_boxes


@pytest.mark.parametrize(
Expand Down Expand Up @@ -631,3 +634,14 @@ def test_mask_non_max_merge(
sorted_result = sorted([sorted(group) for group in result])
sorted_expected_result = sorted([sorted(group) for group in expected_result])
assert sorted_result == sorted_expected_result


def test_box_iou_batch_and_alt_equivalence():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we will rename box_iou_batch_alt to box_iou_batch that test no longer will be necessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added additional test cases rather than using the trivial approach of comparing against the original implementation.

boxes_true = generate_boxes(20, seed=1)
boxes_detection = generate_boxes(30, seed=2)

iou_a = box_iou_batch(boxes_true, boxes_detection)
iou_b = box_iou_batch_alt(boxes_true, boxes_detection)

assert iou_a.shape == iou_b.shape
assert np.allclose(iou_a, iou_b, rtol=1e-6, atol=1e-6)