From 7ea3ed938af84ac53a1cd534b4f318337dbb40fb Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 23 Sep 2025 07:12:36 -0700
Subject: [PATCH 1/2] [torchlib] Migrate torchvision implementations

Signed-off-by: Justin Chu
---
Notes (this section is not part of the commit message):
  * Registers torchvision::roi_align, which calls op.RoiAlign, and
    torchvision::roi_pool, which calls op.MaxRoiPool.
  * Adds three private helpers used by roi_align: one slices column 0 of
    `rois` and casts it to INT64 (batch indices), one slices columns 1..4
    of `rois` (box coordinates), and one replaces a negative
    `sampling_ratio` with 0 after emitting a warning.
  * The warning is raised only when `sampling_ratio < 0`; positive values
    are passed through to op.RoiAlign unchanged. The message therefore
    says "negative" rather than "non-zero" so that it matches the guard.

 .../function_libs/torch_lib/ops/vision.py | 65 +++++++++++++++++++
 1 file changed, 65 insertions(+)

diff --git a/onnxscript/function_libs/torch_lib/ops/vision.py b/onnxscript/function_libs/torch_lib/ops/vision.py
index 57e2f8bd0c..91770f375a 100644
--- a/onnxscript/function_libs/torch_lib/ops/vision.py
+++ b/onnxscript/function_libs/torch_lib/ops/vision.py
@@ -7,6 +7,8 @@
 
 from __future__ import annotations
 
+import warnings
+
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import FLOAT, INT64
@@ -23,3 +25,66 @@ def torchvision_nms(boxes: FLOAT, scores: FLOAT, iou_threshold: float) -> INT64:
     # nms_out: [num_selected_indices, 3] where each column is [batch_index, class_index, box_index]
     nms_out = op.NonMaxSuppression(boxes, scores, _INT64_MAX, iou_threshold)
     return op.Reshape(op.Slice(nms_out, axes=[1], starts=[2], ends=[3]), [-1])
+
+
+def _process_batch_indices_for_roi_align(rois):
+    # Extract batch indices from the first column (index 0) of rois
+    indices = op.Slice(rois, axes=[1], starts=[0], ends=[1])
+    indices = op.Squeeze(indices, axes=[1])
+    return op.Cast(indices, to=INT64.dtype)
+
+
+def _process_rois_for_roi_align(rois):
+    # Extract roi coordinates from columns 1, 2, 3, 4 (x1, y1, x2, y2)
+    return op.Slice(rois, axes=[1], starts=[1], ends=[5])
+
+
+def _process_sampling_ratio_for_roi_align(sampling_ratio: int):
+    if sampling_ratio < 0:
+        warnings.warn(
+            "ONNX export for RoIAlign with a negative sampling_ratio is not supported. "
+            "The model will be exported with a sampling_ratio of 0.",
+            stacklevel=2,
+        )
+        sampling_ratio = 0
+    return sampling_ratio
+
+
+@torch_op("torchvision::roi_align")
+def torchvision_roi_align(
+    input,
+    rois,
+    spatial_scale: float,
+    pooled_height: int,
+    pooled_width: int,
+    sampling_ratio: int = -1,
+    aligned: bool = False,
+):
+    batch_indices = _process_batch_indices_for_roi_align(rois)
+    rois_coords = _process_rois_for_roi_align(rois)
+    coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
+    sampling_ratio = _process_sampling_ratio_for_roi_align(sampling_ratio)
+
+    return op.RoiAlign(
+        input,
+        rois_coords,
+        batch_indices,
+        coordinate_transformation_mode=coordinate_transformation_mode,
+        spatial_scale=spatial_scale,
+        output_height=pooled_height,
+        output_width=pooled_width,
+        sampling_ratio=sampling_ratio,
+    )
+
+
+@torch_op("torchvision::roi_pool")
+def torchvision_roi_pool(
+    input, rois, spatial_scale: float, pooled_height: int, pooled_width: int
+):
+    # MaxRoiPool expects rois in format [batch_index, x1, y1, x2, y2]
+    return op.MaxRoiPool(
+        input,
+        rois,
+        pooled_shape=(pooled_height, pooled_width),
+        spatial_scale=spatial_scale,
+    )

From 4f5d8c34043d7b357d5b2d728c551256a5d91151 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 23 Sep 2025 11:42:12 -0700
Subject: [PATCH 2/2] trace

Signed-off-by: Justin Chu
---
Notes (this section is not part of the commit message):
  * Passes trace_only=True to the torch_op registration of
    torchvision::nms, torchvision::roi_align and torchvision::roi_pool.
    No function body is changed by this patch.

 onnxscript/function_libs/torch_lib/ops/vision.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/vision.py b/onnxscript/function_libs/torch_lib/ops/vision.py
index 91770f375a..ec30cce265 100644
--- a/onnxscript/function_libs/torch_lib/ops/vision.py
+++ b/onnxscript/function_libs/torch_lib/ops/vision.py
@@ -16,7 +16,7 @@
 _INT64_MAX = 0x7FFFFFFFFFFFFFFF
 
 
-@torch_op("torchvision::nms")
+@torch_op("torchvision::nms", trace_only=True)
 def torchvision_nms(boxes: FLOAT, scores: FLOAT, iou_threshold: float) -> INT64:
     # boxes: [num_batches, spatial_dimension, 4]
     boxes = op.Unsqueeze(boxes, [0])
@@ -50,7 +50,7 @@ def _process_sampling_ratio_for_roi_align(sampling_ratio: int):
     return sampling_ratio
 
 
-@torch_op("torchvision::roi_align")
+@torch_op("torchvision::roi_align", trace_only=True)
 def torchvision_roi_align(
     input,
     rois,
@@ -77,7 +77,7 @@ def torchvision_roi_align(
     )
 
 
-@torch_op("torchvision::roi_pool")
+@torch_op("torchvision::roi_pool", trace_only=True)
 def torchvision_roi_pool(
     input, rois, spatial_scale: float, pooled_height: int, pooled_width: int
 ):