Skip to content

Commit 35ff87b

Browse files
committed
feat: speed up box iou batch using 2d in-place ops instead of 3d
without creating intermediary (N, M, 2) arrays.
1 parent 171687f commit 35ff87b

File tree

3 files changed

+142
-0
lines changed

3 files changed

+142
-0
lines changed

supervision/detection/utils/iou_and_nms.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,96 @@ def box_area(box):
231231
return ious
232232

233233

234+
def box_iou_batch_alt(
235+
boxes_true: np.ndarray,
236+
boxes_detection: np.ndarray,
237+
overlap_metric: OverlapMetric = OverlapMetric.IOU,
238+
) -> np.ndarray:
239+
"""
240+
Compute Intersection over Union (IoU) of two sets of bounding boxes -
241+
`boxes_true` and `boxes_detection`. Both sets
242+
of boxes are expected to be in `(x_min, y_min, x_max, y_max)` format.
243+
244+
Note:
245+
Use `box_iou` when computing IoU between two individual boxes.
246+
For comparing multiple boxes (arrays of boxes), use `box_iou_batch` for better
247+
performance.
248+
249+
Args:
250+
boxes_true (np.ndarray): 2D `np.ndarray` representing ground-truth boxes.
251+
`shape = (N, 4)` where `N` is number of true objects.
252+
boxes_detection (np.ndarray): 2D `np.ndarray` representing detection boxes.
253+
`shape = (M, 4)` where `M` is number of detected objects.
254+
overlap_metric (OverlapMetric): Metric used to compute the degree of overlap
255+
between pairs of boxes (e.g., IoU, IoS).
256+
257+
Returns:
258+
np.ndarray: Pairwise IoU of boxes from `boxes_true` and `boxes_detection`.
259+
`shape = (N, M)` where `N` is number of true objects and
260+
`M` is number of detected objects.
261+
262+
Examples:
263+
```python
264+
import numpy as np
265+
import supervision as sv
266+
267+
boxes_true = np.array([
268+
[100, 100, 200, 200],
269+
[300, 300, 400, 400]
270+
])
271+
boxes_detection = np.array([
272+
[150, 150, 250, 250],
273+
[320, 320, 420, 420]
274+
])
275+
276+
sv.box_iou_batch(boxes_true=boxes_true, boxes_detection=boxes_detection)
277+
# array([
278+
# [0.14285714, 0. ],
279+
# [0. , 0.47058824]
280+
# ])
281+
```
282+
283+
"""
284+
285+
tx1, ty1, tx2, ty2 = boxes_true.T
286+
dx1, dy1, dx2, dy2 = boxes_detection.T
287+
N, M = boxes_true.shape[0], boxes_detection.shape[0]
288+
289+
top_left_x = np.empty((N, M), dtype=np.float32)
290+
bottom_right_x = np.empty_like(top_left_x)
291+
top_left_y = np.empty_like(top_left_x)
292+
bottom_right_y = np.empty_like(top_left_x)
293+
294+
np.maximum(tx1[:, None], dx1[None, :], out=top_left_x)
295+
np.minimum(tx2[:, None], dx2[None, :], out=bottom_right_x)
296+
np.maximum(ty1[:, None], dy1[None, :], out=top_left_y)
297+
np.minimum(ty2[:, None], dy2[None, :], out=bottom_right_y)
298+
299+
np.subtract(bottom_right_x, top_left_x, out=bottom_right_x) # W
300+
np.subtract(bottom_right_y, top_left_y, out=bottom_right_y) # H
301+
np.clip(bottom_right_x, 0.0, None, out=bottom_right_x)
302+
np.clip(bottom_right_y, 0.0, None, out=bottom_right_y)
303+
304+
area_inter = bottom_right_x * bottom_right_y
305+
306+
area_true = (tx2 - tx1) * (ty2 - ty1)
307+
area_detection = (dx2 - dx1) * (dy2 - dy1)
308+
309+
if overlap_metric == OverlapMetric.IOU:
310+
denom = area_true[:, None] + area_detection[None, :] - area_inter
311+
elif overlap_metric == OverlapMetric.IOS:
312+
denom = np.minimum(area_true[:, None], area_detection[None, :])
313+
else:
314+
raise ValueError(
315+
f"overlap_metric {overlap_metric} is not supported, "
316+
"only 'IOU' and 'IOS' are supported"
317+
)
318+
319+
out = np.zeros_like(area_inter, dtype=np.float32)
320+
np.divide(area_inter, denom, out=out, where=denom > 0)
321+
return out
322+
323+
234324
def _jaccard(box_a: list[float], box_b: list[float], is_crowd: bool) -> float:
235325
"""
236326
Calculate the Jaccard index (intersection over union) between two bounding boxes.

test/detection/utils/functions.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import random
2+
3+
import numpy as np
4+
5+
6+
def generate_boxes(
7+
n: int,
8+
W: int = 1920,
9+
H: int = 1080,
10+
min_size: int = 20,
11+
max_size: int = 200,
12+
seed: int | None = 1,
13+
):
14+
"""
15+
Generate N valid bounding boxes of format [x_min, y_min, x_max, y_max].
16+
17+
Args:
18+
n (int): Number of boexs to generate
19+
W (int): Image width
20+
H (int): Image height
21+
min_size (int): Minimum box size (width/height)
22+
max_size (int): Maximum box size (width/height)
23+
seed (int | None): Random seed for reproducibility
24+
25+
Returns:
26+
list[list[float]] | np.ndarray: List of boxes
27+
"""
28+
random.seed(seed)
29+
boxes = []
30+
for _ in range(n):
31+
w = random.uniform(min_size, max_size)
32+
h = random.uniform(min_size, max_size)
33+
x1 = random.uniform(0, W - w)
34+
y1 = random.uniform(0, H - h)
35+
x2 = x1 + w
36+
y2 = y1 + h
37+
boxes.append([x1, y1, x2, y2])
38+
return np.array(boxes, dtype=np.float32)

test/detection/utils/test_iou_and_nms.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77

88
from supervision.detection.utils.iou_and_nms import (
99
_group_overlapping_boxes,
10+
box_iou_batch,
11+
box_iou_batch_alt,
1012
box_non_max_suppression,
1113
mask_non_max_merge,
1214
mask_non_max_suppression,
1315
)
16+
from test.detection.utils.functions import generate_boxes
1417

1518

1619
@pytest.mark.parametrize(
@@ -631,3 +634,14 @@ def test_mask_non_max_merge(
631634
sorted_result = sorted([sorted(group) for group in result])
632635
sorted_expected_result = sorted([sorted(group) for group in expected_result])
633636
assert sorted_result == sorted_expected_result
637+
638+
639+
def test_box_iou_batch_and_alt_equivalence():
640+
boxes_true = generate_boxes(20, seed=1)
641+
boxes_detection = generate_boxes(30, seed=2)
642+
643+
iou_a = box_iou_batch(boxes_true, boxes_detection)
644+
iou_b = box_iou_batch_alt(boxes_true, boxes_detection)
645+
646+
assert iou_a.shape == iou_b.shape
647+
assert np.allclose(iou_a, iou_b, rtol=1e-6, atol=1e-6)

0 commit comments

Comments
 (0)