Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 26 additions & 23 deletions supervision/detection/utils/iou_and_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,43 +192,46 @@ def box_iou_batch(
# [0. , 0.47058824]
# ])
```

"""

def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
tx1, ty1, tx2, ty2 = boxes_true.T
dx1, dy1, dx2, dy2 = boxes_detection.T
N, M = boxes_true.shape[0], boxes_detection.shape[0]

top_left_x = np.empty((N, M), dtype=np.float32)
bottom_right_x = np.empty_like(top_left_x)
top_left_y = np.empty_like(top_left_x)
bottom_right_y = np.empty_like(top_left_x)

area_true = box_area(boxes_true.T)
area_detection = box_area(boxes_detection.T)
np.maximum(tx1[:, None], dx1[None, :], out=top_left_x)
np.minimum(tx2[:, None], dx2[None, :], out=bottom_right_x)
np.maximum(ty1[:, None], dy1[None, :], out=top_left_y)
np.minimum(ty2[:, None], dy2[None, :], out=bottom_right_y)

top_left = np.maximum(boxes_true[:, None, :2], boxes_detection[:, :2])
bottom_right = np.minimum(boxes_true[:, None, 2:], boxes_detection[:, 2:])
np.subtract(bottom_right_x, top_left_x, out=bottom_right_x) # W
np.subtract(bottom_right_y, top_left_y, out=bottom_right_y) # H
np.clip(bottom_right_x, 0.0, None, out=bottom_right_x)
np.clip(bottom_right_y, 0.0, None, out=bottom_right_y)

area_inter = np.prod(np.clip(bottom_right - top_left, a_min=0, a_max=None), 2)
area_inter = bottom_right_x * bottom_right_y

area_true = (tx2 - tx1) * (ty2 - ty1)
area_detection = (dx2 - dx1) * (dy2 - dy1)

if overlap_metric == OverlapMetric.IOU:
union_area = area_true[:, None] + area_detection - area_inter
ious = np.divide(
area_inter,
union_area,
out=np.zeros_like(area_inter, dtype=float),
where=union_area != 0,
)
denom = area_true[:, None] + area_detection[None, :] - area_inter
elif overlap_metric == OverlapMetric.IOS:
small_area = np.minimum(area_true[:, None], area_detection)
ious = np.divide(
area_inter,
small_area,
out=np.zeros_like(area_inter, dtype=float),
where=small_area != 0,
)
denom = np.minimum(area_true[:, None], area_detection[None, :])
else:
raise ValueError(
f"overlap_metric {overlap_metric} is not supported, "
"only 'IOU' and 'IOS' are supported"
)

ious = np.nan_to_num(ious)
return ious
out = np.zeros_like(area_inter, dtype=np.float32)
np.divide(area_inter, denom, out=out, where=denom > 0)
return out


def _jaccard(box_a: list[float], box_b: list[float], is_crowd: bool) -> float:
Expand Down
105 changes: 105 additions & 0 deletions test/detection/utils/test_iou_and_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

from supervision.detection.utils.iou_and_nms import (
_group_overlapping_boxes,
box_iou,
box_iou_batch,
box_non_max_suppression,
mask_non_max_merge,
mask_non_max_suppression,
)
from test.test_utils import mock_boxes


@pytest.mark.parametrize(
Expand Down Expand Up @@ -631,3 +634,105 @@ def test_mask_non_max_merge(
sorted_result = sorted([sorted(group) for group in result])
sorted_expected_result = sorted([sorted(group) for group in expected_result])
assert sorted_result == sorted_expected_result


@pytest.mark.parametrize(
"boxes_true, boxes_detection, expected_iou, exception",
[
(
np.empty((0, 4), dtype=np.float32),
np.empty((0, 4), dtype=np.float32),
np.empty((0, 0), dtype=np.float32),
DoesNotRaise(),
), # empty
(
np.array([[0, 0, 10, 10]], dtype=np.float32),
np.empty((0, 4), dtype=np.float32),
np.empty((1, 0), dtype=np.float32),
DoesNotRaise(),
), # one true box, no detections
(
np.empty((0, 4), dtype=np.float32),
np.array([[0, 0, 10, 10]], dtype=np.float32),
np.empty((0, 1), dtype=np.float32),
DoesNotRaise(),
), # no true boxes, one detection
(
np.array([[0, 0, 10, 10]], dtype=np.float32),
np.array([[0, 0, 10, 10]], dtype=np.float32),
np.array([[1.0]]),
DoesNotRaise(),
), # perfect overlap
(
np.array([[0, 0, 10, 10]], dtype=np.float32),
np.array([[20, 20, 30, 30]], dtype=np.float32),
np.array([[0.0]]),
DoesNotRaise(),
), # no overlap
(
np.array([[0, 0, 10, 10]], dtype=np.float32),
np.array([[5, 5, 15, 15]], dtype=np.float32),
np.array([[25.0 / 175.0]]), # intersection: 5x5=25, union: 100+100-25=175
DoesNotRaise(),
), # partial overlap
(
np.array([[0, 0, 10, 10]], dtype=np.float32),
np.array([[0, 0, 5, 5]], dtype=np.float32),
np.array([[25.0 / 100.0]]), # intersection: 5x5=25, union: 100
DoesNotRaise(),
), # detection inside true box
(
np.array([[0, 0, 5, 5]], dtype=np.float32),
np.array([[0, 0, 10, 10]], dtype=np.float32),
np.array([[25.0 / 100.0]]), # true box inside detection
DoesNotRaise(),
),
(
np.array([[0, 0, 10, 10], [20, 20, 30, 30]], dtype=np.float32),
np.array([[0, 0, 10, 10], [20, 20, 30, 30]], dtype=np.float32),
np.array([[1.0, 0.0], [0.0, 1.0]]),
DoesNotRaise(),
), # two boxes, perfect matches
],
)
def test_box_iou_batch(
boxes_true: np.ndarray,
boxes_detection: np.ndarray,
expected_iou: np.ndarray,
exception: Exception,
) -> None:
with exception:
result = box_iou_batch(boxes_true, boxes_detection)
assert result.shape == expected_iou.shape
assert np.allclose(result, expected_iou, rtol=1e-5, atol=1e-5)


def test_box_iou_batch_consistency_with_box_iou():
"""Test that box_iou_batch gives same results as box_iou for single boxes."""
boxes_true = np.array(mock_boxes(5, seed=1), dtype=np.float32)
boxes_detection = np.array(mock_boxes(5, seed=2), dtype=np.float32)

batch_result = box_iou_batch(boxes_true, boxes_detection)

for i, box_true in enumerate(boxes_true):
for j, box_detection in enumerate(boxes_detection):
single_result = box_iou(box_true, box_detection)
assert np.allclose(batch_result[i, j], single_result, rtol=1e-5, atol=1e-5)


def test_box_iou_batch_with_mock_detections():
"""Test box_iou_batch with generated boxes and verify results are valid."""
boxes_true = np.array(mock_boxes(10, seed=1), dtype=np.float32)
boxes_detection = np.array(mock_boxes(15, seed=2), dtype=np.float32)

result = box_iou_batch(boxes_true, boxes_detection)

assert result.shape == (10, 15)

assert np.all(result >= 0)
assert np.all(result <= 1.0)

# and symmetric
result_reversed = box_iou_batch(boxes_detection, boxes_true)
assert result_reversed.shape == (15, 10)
assert np.allclose(result.T, result_reversed, rtol=1e-5, atol=1e-5)
36 changes: 36 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import random
from typing import Any

import numpy as np
Expand Down Expand Up @@ -52,5 +53,40 @@ def convert_data(data: dict[str, list[Any]]):
)


def mock_boxes(
n: int,
resolution_wh: tuple[int, int] = (1920, 1080),
min_size: int = 20,
max_size: int = 200,
seed: int | None = None,
) -> list[list[float]]:
"""
Generate N valid bounding boxes of format [x_min, y_min, x_max, y_max].

Args:
n: Number of boxes to generate.
resolution_wh: Image resolution as (width, height). Defaults to (1920, 1080).
min_size: Minimum box size (width/height). Defaults to 20.
max_size: Maximum box size (width/height). Defaults to 200.
seed: Random seed for reproducibility. Defaults to None.

Returns:
List of boxes, each as [x_min, y_min, x_max, y_max].
"""
if seed is not None:
random.seed(seed)
width, height = resolution_wh
boxes = []
for _ in range(n):
w = random.uniform(min_size, max_size)
h = random.uniform(min_size, max_size)
x1 = random.uniform(0, width - w)
y1 = random.uniform(0, height - h)
x2 = x1 + w
y2 = y1 + h
boxes.append([x1, y1, x2, y2])
return boxes


def assert_almost_equal(actual, expected, tolerance=1e-5):
assert abs(actual - expected) < tolerance, f"Expected {expected}, but got {actual}."