From 35ff87b08d88a18a855c5325b97fe8fa3072adb0 Mon Sep 17 00:00:00 2001 From: AnonymDevOSS Date: Tue, 28 Oct 2025 13:51:28 +0100 Subject: [PATCH 1/4] feat: speed up box iou batch using 2d in-place ops instead of 3d without creating intermediary (N, M, 2) arrays. --- supervision/detection/utils/iou_and_nms.py | 90 ++++++++++++++++++++++ test/detection/utils/functions.py | 38 +++++++++ test/detection/utils/test_iou_and_nms.py | 14 ++++ 3 files changed, 142 insertions(+) create mode 100644 test/detection/utils/functions.py diff --git a/supervision/detection/utils/iou_and_nms.py b/supervision/detection/utils/iou_and_nms.py index 1a6f80bc5..b55eea876 100644 --- a/supervision/detection/utils/iou_and_nms.py +++ b/supervision/detection/utils/iou_and_nms.py @@ -231,6 +231,96 @@ def box_area(box): return ious +def box_iou_batch_alt( + boxes_true: np.ndarray, + boxes_detection: np.ndarray, + overlap_metric: OverlapMetric = OverlapMetric.IOU, +) -> np.ndarray: + """ + Compute Intersection over Union (IoU) of two sets of bounding boxes - + `boxes_true` and `boxes_detection`. Both sets + of boxes are expected to be in `(x_min, y_min, x_max, y_max)` format. + + Note: + Use `box_iou` when computing IoU between two individual boxes. + For comparing multiple boxes (arrays of boxes), use `box_iou_batch` for better + performance. + + Args: + boxes_true (np.ndarray): 2D `np.ndarray` representing ground-truth boxes. + `shape = (N, 4)` where `N` is number of true objects. + boxes_detection (np.ndarray): 2D `np.ndarray` representing detection boxes. + `shape = (M, 4)` where `M` is number of detected objects. + overlap_metric (OverlapMetric): Metric used to compute the degree of overlap + between pairs of boxes (e.g., IoU, IoS). + + Returns: + np.ndarray: Pairwise IoU of boxes from `boxes_true` and `boxes_detection`. + `shape = (N, M)` where `N` is number of true objects and + `M` is number of detected objects. + + Examples: + ```python + import numpy as np + import supervision as sv + + boxes_true = np.array([ + [100, 100, 200, 200], + [300, 300, 400, 400] + ]) + boxes_detection = np.array([ + [150, 150, 250, 250], + [320, 320, 420, 420] + ]) + + sv.box_iou_batch(boxes_true=boxes_true, boxes_detection=boxes_detection) + # array([ + # [0.14285714, 0. ], + # [0. , 0.47058824] + # ]) + ``` + + """ + + tx1, ty1, tx2, ty2 = boxes_true.T + dx1, dy1, dx2, dy2 = boxes_detection.T + N, M = boxes_true.shape[0], boxes_detection.shape[0] + + top_left_x = np.empty((N, M), dtype=np.float32) + bottom_right_x = np.empty_like(top_left_x) + top_left_y = np.empty_like(top_left_x) + bottom_right_y = np.empty_like(top_left_x) + + np.maximum(tx1[:, None], dx1[None, :], out=top_left_x) + np.minimum(tx2[:, None], dx2[None, :], out=bottom_right_x) + np.maximum(ty1[:, None], dy1[None, :], out=top_left_y) + np.minimum(ty2[:, None], dy2[None, :], out=bottom_right_y) + + np.subtract(bottom_right_x, top_left_x, out=bottom_right_x) # W + np.subtract(bottom_right_y, top_left_y, out=bottom_right_y) # H + np.clip(bottom_right_x, 0.0, None, out=bottom_right_x) + np.clip(bottom_right_y, 0.0, None, out=bottom_right_y) + + area_inter = bottom_right_x * bottom_right_y + + area_true = (tx2 - tx1) * (ty2 - ty1) + area_detection = (dx2 - dx1) * (dy2 - dy1) + + if overlap_metric == OverlapMetric.IOU: + denom = area_true[:, None] + area_detection[None, :] - area_inter + elif overlap_metric == OverlapMetric.IOS: + denom = np.minimum(area_true[:, None], area_detection[None, :]) + else: + raise ValueError( + f"overlap_metric {overlap_metric} is not supported, " + "only 'IOU' and 'IOS' are supported" + ) + + out = np.zeros_like(area_inter, dtype=np.float32) + np.divide(area_inter, denom, out=out, where=denom > 0) + return out + + def _jaccard(box_a: list[float], box_b: list[float], is_crowd: bool) -> float: """ Calculate the Jaccard index (intersection over union) between two bounding boxes. diff --git a/test/detection/utils/functions.py b/test/detection/utils/functions.py new file mode 100644 index 000000000..6b10dfa2f --- /dev/null +++ b/test/detection/utils/functions.py @@ -0,0 +1,38 @@ +import random + +import numpy as np + + +def generate_boxes( + n: int, + W: int = 1920, + H: int = 1080, + min_size: int = 20, + max_size: int = 200, + seed: int | None = 1, +): + """ + Generate N valid bounding boxes of format [x_min, y_min, x_max, y_max]. + + Args: + n (int): Number of boexs to generate + W (int): Image width + H (int): Image height + min_size (int): Minimum box size (width/height) + max_size (int): Maximum box size (width/height) + seed (int | None): Random seed for reproducibility + + Returns: + list[list[float]] | np.ndarray: List of boxes + """ + random.seed(seed) + boxes = [] + for _ in range(n): + w = random.uniform(min_size, max_size) + h = random.uniform(min_size, max_size) + x1 = random.uniform(0, W - w) + y1 = random.uniform(0, H - h) + x2 = x1 + w + y2 = y1 + h + boxes.append([x1, y1, x2, y2]) + return np.array(boxes, dtype=np.float32) diff --git a/test/detection/utils/test_iou_and_nms.py b/test/detection/utils/test_iou_and_nms.py index 8039bf242..87fd958ad 100644 --- a/test/detection/utils/test_iou_and_nms.py +++ b/test/detection/utils/test_iou_and_nms.py @@ -7,10 +7,13 @@ from supervision.detection.utils.iou_and_nms import ( _group_overlapping_boxes, + box_iou_batch, + box_iou_batch_alt, box_non_max_suppression, mask_non_max_merge, mask_non_max_suppression, ) +from test.detection.utils.functions import generate_boxes @pytest.mark.parametrize( @@ -631,3 +634,14 @@ def test_mask_non_max_merge( sorted_result = sorted([sorted(group) for group in result]) sorted_expected_result = sorted([sorted(group) for group in expected_result]) assert sorted_result == sorted_expected_result + + +def test_box_iou_batch_and_alt_equivalence(): + boxes_true = generate_boxes(20, seed=1) + boxes_detection = generate_boxes(30, seed=2) + + iou_a = box_iou_batch(boxes_true, boxes_detection) + iou_b = box_iou_batch_alt(boxes_true, boxes_detection) + + assert iou_a.shape == iou_b.shape + assert np.allclose(iou_a, iou_b, rtol=1e-6, atol=1e-6) From ceefe0bab2e96c84debd24a1b2a1a234e72817fb Mon Sep 17 00:00:00 2001 From: AnonymDevOSS Date: Thu, 6 Nov 2025 22:52:25 +0100 Subject: [PATCH 2/4] feat/ speed up box iou - replaced original function; added tests --- supervision/detection/utils/iou_and_nms.py | 87 ---------------- test/detection/utils/functions.py | 38 ------- test/detection/utils/test_iou_and_nms.py | 112 +++++++++++++++++++-- test/test_utils.py | 36 +++++++ 4 files changed, 139 insertions(+), 134 deletions(-) delete mode 100644 test/detection/utils/functions.py diff --git a/supervision/detection/utils/iou_and_nms.py b/supervision/detection/utils/iou_and_nms.py index b55eea876..299a61609 100644 --- a/supervision/detection/utils/iou_and_nms.py +++ b/supervision/detection/utils/iou_and_nms.py @@ -172,93 +172,6 @@ def box_iou_batch( `shape = (N, M)` where `N` is number of true objects and `M` is number of detected objects. - Examples: - ```python - import numpy as np - import supervision as sv - - boxes_true = np.array([ - [100, 100, 200, 200], - [300, 300, 400, 400] - ]) - boxes_detection = np.array([ - [150, 150, 250, 250], - [320, 320, 420, 420] - ]) - - sv.box_iou_batch(boxes_true=boxes_true, boxes_detection=boxes_detection) - # array([ - # [0.14285714, 0. ], - # [0. , 0.47058824] - # ]) - ``` - """ - - def box_area(box): - return (box[2] - box[0]) * (box[3] - box[1]) - - area_true = box_area(boxes_true.T) - area_detection = box_area(boxes_detection.T) - - top_left = np.maximum(boxes_true[:, None, :2], boxes_detection[:, :2]) - bottom_right = np.minimum(boxes_true[:, None, 2:], boxes_detection[:, 2:]) - - area_inter = np.prod(np.clip(bottom_right - top_left, a_min=0, a_max=None), 2) - - if overlap_metric == OverlapMetric.IOU: - union_area = area_true[:, None] + area_detection - area_inter - ious = np.divide( - area_inter, - union_area, - out=np.zeros_like(area_inter, dtype=float), - where=union_area != 0, - ) - elif overlap_metric == OverlapMetric.IOS: - small_area = np.minimum(area_true[:, None], area_detection) - ious = np.divide( - area_inter, - small_area, - out=np.zeros_like(area_inter, dtype=float), - where=small_area != 0, - ) - else: - raise ValueError( - f"overlap_metric {overlap_metric} is not supported, " - "only 'IOU' and 'IOS' are supported" - ) - - ious = np.nan_to_num(ious) - return ious - - -def box_iou_batch_alt( - boxes_true: np.ndarray, - boxes_detection: np.ndarray, - overlap_metric: OverlapMetric = OverlapMetric.IOU, -) -> np.ndarray: - """ - Compute Intersection over Union (IoU) of two sets of bounding boxes - - `boxes_true` and `boxes_detection`. Both sets - of boxes are expected to be in `(x_min, y_min, x_max, y_max)` format. - - Note: - Use `box_iou` when computing IoU between two individual boxes. - For comparing multiple boxes (arrays of boxes), use `box_iou_batch` for better - performance. - - Args: - boxes_true (np.ndarray): 2D `np.ndarray` representing ground-truth boxes. - `shape = (N, 4)` where `N` is number of true objects. - boxes_detection (np.ndarray): 2D `np.ndarray` representing detection boxes. - `shape = (M, 4)` where `M` is number of detected objects. - overlap_metric (OverlapMetric): Metric used to compute the degree of overlap - between pairs of boxes (e.g., IoU, IoS). - - Returns: - np.ndarray: Pairwise IoU of boxes from `boxes_true` and `boxes_detection`. - `shape = (N, M)` where `N` is number of true objects and - `M` is number of detected objects. - Examples: ```python import numpy as np diff --git a/test/detection/utils/functions.py b/test/detection/utils/functions.py deleted file mode 100644 index 6b10dfa2f..000000000 --- a/test/detection/utils/functions.py +++ /dev/null @@ -1,38 +0,0 @@ -import random - -import numpy as np - - -def generate_boxes( - n: int, - W: int = 1920, - H: int = 1080, - min_size: int = 20, - max_size: int = 200, - seed: int | None = 1, -): - """ - Generate N valid bounding boxes of format [x_min, y_min, x_max, y_max]. - - Args: - n (int): Number of boexs to generate - W (int): Image width - H (int): Image height - min_size (int): Minimum box size (width/height) - max_size (int): Maximum box size (width/height) - seed (int | None): Random seed for reproducibility - - Returns: - list[list[float]] | np.ndarray: List of boxes - """ - random.seed(seed) - boxes = [] - for _ in range(n): - w = random.uniform(min_size, max_size) - h = random.uniform(min_size, max_size) - x1 = random.uniform(0, W - w) - y1 = random.uniform(0, H - h) - x2 = x1 + w - y2 = y1 + h - boxes.append([x1, y1, x2, y2]) - return np.array(boxes, dtype=np.float32) diff --git a/test/detection/utils/test_iou_and_nms.py b/test/detection/utils/test_iou_and_nms.py index 87fd958ad..6536d29c0 100644 --- a/test/detection/utils/test_iou_and_nms.py +++ b/test/detection/utils/test_iou_and_nms.py @@ -6,14 +6,15 @@ import pytest from supervision.detection.utils.iou_and_nms import ( + OverlapMetric, _group_overlapping_boxes, + box_iou, box_iou_batch, - box_iou_batch_alt, box_non_max_suppression, mask_non_max_merge, mask_non_max_suppression, ) -from test.detection.utils.functions import generate_boxes +from test.test_utils import mock_boxes @pytest.mark.parametrize( @@ -636,12 +637,105 @@ def test_mask_non_max_merge( assert sorted_result == sorted_expected_result -def test_box_iou_batch_and_alt_equivalence(): - boxes_true = generate_boxes(20, seed=1) - boxes_detection = generate_boxes(30, seed=2) +@pytest.mark.parametrize( + "boxes_true, boxes_detection, expected_iou, exception", + [ + ( + np.empty((0, 4), dtype=np.float32), + np.empty((0, 4), dtype=np.float32), + np.empty((0, 0), dtype=np.float32), + DoesNotRaise(), + ), # empty + ( + np.array([[0, 0, 10, 10]], dtype=np.float32), + np.empty((0, 4), dtype=np.float32), + np.empty((1, 0), dtype=np.float32), + DoesNotRaise(), + ), # one true box, no detections + ( + np.empty((0, 4), dtype=np.float32), + np.array([[0, 0, 10, 10]], dtype=np.float32), + np.empty((0, 1), dtype=np.float32), + DoesNotRaise(), + ), # no true boxes, one detection + ( + np.array([[0, 0, 10, 10]], dtype=np.float32), + np.array([[0, 0, 10, 10]], dtype=np.float32), + np.array([[1.0]]), + DoesNotRaise(), + ), # perfect overlap + ( + np.array([[0, 0, 10, 10]], dtype=np.float32), + np.array([[20, 20, 30, 30]], dtype=np.float32), + np.array([[0.0]]), + DoesNotRaise(), + ), # no overlap + ( + np.array([[0, 0, 10, 10]], dtype=np.float32), + np.array([[5, 5, 15, 15]], dtype=np.float32), + np.array([[25.0 / 175.0]]), # intersection: 5x5=25, union: 100+100-25=175 + DoesNotRaise(), + ), # partial overlap + ( + np.array([[0, 0, 10, 10]], dtype=np.float32), + np.array([[0, 0, 5, 5]], dtype=np.float32), + np.array([[25.0 / 100.0]]), # intersection: 5x5=25, union: 100 + DoesNotRaise(), + ), # detection inside true box + ( + np.array([[0, 0, 5, 5]], dtype=np.float32), + np.array([[0, 0, 10, 10]], dtype=np.float32), + np.array([[25.0 / 100.0]]), # true box inside detection + DoesNotRaise(), + ), + ( + np.array([[0, 0, 10, 10], [20, 20, 30, 30]], dtype=np.float32), + np.array([[0, 0, 10, 10], [20, 20, 30, 30]], dtype=np.float32), + np.array([[1.0, 0.0], [0.0, 1.0]]), + DoesNotRaise(), + ), # two boxes, perfect matches + ], +) +def test_box_iou_batch( + boxes_true: np.ndarray, + boxes_detection: np.ndarray, + expected_iou: np.ndarray, + exception: Exception, +) -> None: + with exception: + result = box_iou_batch(boxes_true, boxes_detection) + assert result.shape == expected_iou.shape + assert np.allclose(result, expected_iou, rtol=1e-5, atol=1e-5) + + +def test_box_iou_batch_consistency_with_box_iou(): + """Test that box_iou_batch gives same results as box_iou for single boxes.""" + boxes_true = np.array(mock_boxes(5, seed=1), dtype=np.float32) + boxes_detection = np.array(mock_boxes(5, seed=2), dtype=np.float32) + + batch_result = box_iou_batch(boxes_true, boxes_detection) + + for i, box_true in enumerate(boxes_true): + for j, box_detection in enumerate(boxes_detection): + single_result = box_iou(box_true, box_detection) + assert np.allclose( + batch_result[i, j], single_result, rtol=1e-5, atol=1e-5 + ) + + +def test_box_iou_batch_with_mock_detections(): + """ Test box_iou_batch with generated boxes and verify results are valid. """ + boxes_true = np.array(mock_boxes(10, seed=1), dtype=np.float32) + boxes_detection = np.array(mock_boxes(15, seed=2), dtype=np.float32) - iou_a = box_iou_batch(boxes_true, boxes_detection) - iou_b = box_iou_batch_alt(boxes_true, boxes_detection) + result = box_iou_batch(boxes_true, boxes_detection) - assert iou_a.shape == iou_b.shape - assert np.allclose(iou_a, iou_b, rtol=1e-6, atol=1e-6) + assert result.shape == (10, 15) + + assert np.all(result >= 0) + assert np.all(result <= 1.0) + + # and symetric + result_reversed = box_iou_batch(boxes_detection, boxes_true) + assert result_reversed.shape == (15, 10) + assert np.allclose(result.T, result_reversed, rtol=1e-5, atol=1e-5) diff --git a/test/test_utils.py b/test/test_utils.py index 0a97bf4bf..e512de6f6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import random from typing import Any import numpy as np @@ -52,5 +53,40 @@ def convert_data(data: dict[str, list[Any]]): ) +def mock_boxes( + n: int, + resolution_wh: tuple[int, int] = (1920, 1080), + min_size: int = 20, + max_size: int = 200, + seed: int | None = None, +) -> list[list[float]]: + """ + Generate N valid bounding boxes of format [x_min, y_min, x_max, y_max]. + + Args: + n: Number of boxes to generate. + resolution_wh: Image resolution as (width, height). Defaults to (1920, 1080). + min_size: Minimum box size (width/height). Defaults to 20. + max_size: Maximum box size (width/height). Defaults to 200. + seed: Random seed for reproducibility. Defaults to None. + + Returns: + List of boxes, each as [x_min, y_min, x_max, y_max]. + """ + if seed is not None: + random.seed(seed) + width, height = resolution_wh + boxes = [] + for _ in range(n): + w = random.uniform(min_size, max_size) + h = random.uniform(min_size, max_size) + x1 = random.uniform(0, width - w) + y1 = random.uniform(0, height - h) + x2 = x1 + w + y2 = y1 + h + boxes.append([x1, y1, x2, y2]) + return boxes + + def assert_almost_equal(actual, expected, tolerance=1e-5): assert abs(actual - expected) < tolerance, f"Expected {expected}, but got {actual}." From da7c5cd65ab58915793f6c50081d53a3cbeaf956 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Nov 2025 21:57:23 +0000 Subject: [PATCH 3/4] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/detection/utils/test_iou_and_nms.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/test/detection/utils/test_iou_and_nms.py b/test/detection/utils/test_iou_and_nms.py index 6536d29c0..765bc5a28 100644 --- a/test/detection/utils/test_iou_and_nms.py +++ b/test/detection/utils/test_iou_and_nms.py @@ -6,7 +6,6 @@ import pytest from supervision.detection.utils.iou_and_nms import ( - OverlapMetric, _group_overlapping_boxes, box_iou, box_iou_batch, @@ -718,23 +717,21 @@ def test_box_iou_batch_consistency_with_box_iou(): for i, box_true in enumerate(boxes_true): for j, box_detection in enumerate(boxes_detection): single_result = box_iou(box_true, box_detection) - assert np.allclose( - batch_result[i, j], single_result, rtol=1e-5, atol=1e-5 - ) + assert np.allclose(batch_result[i, j], single_result, rtol=1e-5, atol=1e-5) def test_box_iou_batch_with_mock_detections(): - """ Test box_iou_batch with generated boxes and verify results are valid. """ + """Test box_iou_batch with generated boxes and verify results are valid.""" boxes_true = np.array(mock_boxes(10, seed=1), dtype=np.float32) boxes_detection = np.array(mock_boxes(15, seed=2), dtype=np.float32) result = box_iou_batch(boxes_true, boxes_detection) assert result.shape == (10, 15) - + assert np.all(result >= 0) assert np.all(result <= 1.0) - + # and symetric result_reversed = box_iou_batch(boxes_detection, boxes_true) assert result_reversed.shape == (15, 10) From bb87ab0a6ed6457f73ccc3fd4f5aeb1bb635563c Mon Sep 17 00:00:00 2001 From: AnonymDevOSS Date: Thu, 6 Nov 2025 23:00:33 +0100 Subject: [PATCH 4/4] fixed autoformat errors --- test/detection/utils/test_iou_and_nms.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/test/detection/utils/test_iou_and_nms.py b/test/detection/utils/test_iou_and_nms.py index 6536d29c0..29c11dded 100644 --- a/test/detection/utils/test_iou_and_nms.py +++ b/test/detection/utils/test_iou_and_nms.py @@ -6,7 +6,6 @@ import pytest from supervision.detection.utils.iou_and_nms import ( - OverlapMetric, _group_overlapping_boxes, box_iou, box_iou_batch, @@ -718,24 +717,22 @@ def test_box_iou_batch_consistency_with_box_iou(): for i, box_true in enumerate(boxes_true): for j, box_detection in enumerate(boxes_detection): single_result = box_iou(box_true, box_detection) - assert np.allclose( - batch_result[i, j], single_result, rtol=1e-5, atol=1e-5 - ) + assert np.allclose(batch_result[i, j], single_result, rtol=1e-5, atol=1e-5) def test_box_iou_batch_with_mock_detections(): - """ Test box_iou_batch with generated boxes and verify results are valid. """ + """Test box_iou_batch with generated boxes and verify results are valid.""" boxes_true = np.array(mock_boxes(10, seed=1), dtype=np.float32) boxes_detection = np.array(mock_boxes(15, seed=2), dtype=np.float32) result = box_iou_batch(boxes_true, boxes_detection) assert result.shape == (10, 15) - + assert np.all(result >= 0) assert np.all(result <= 1.0) - - # and symetric + + # and symmetric result_reversed = box_iou_batch(boxes_detection, boxes_true) assert result_reversed.shape == (15, 10) assert np.allclose(result.T, result_reversed, rtol=1e-5, atol=1e-5)