From ce7672264e36abf48b37d9a520720ebe1c58cf43 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 12:31:15 -0700 Subject: [PATCH 1/7] Consolidate all overloads and prevent new ones from being created Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 20 +++++++++++++------ .../function_libs/torch_lib/ops_test_data.py | 18 ++--------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e837bfadae..162672696e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4060,7 +4060,9 @@ def _aten_index_onnx( @torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True) -def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType: +def aten_index( + self: TensorType, indices: Sequence[Optional[Union[INT64, BOOL]]] +) -> TensorType: """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor NOTE: Understanding `aten::index` @@ -4080,14 +4082,17 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy None in `indices` are like fillers for dimensions that cannot be removed in the process. """ + # Handle Boolean indexing first + for index in indices: + if index is not None and index.dtype == BOOL.dtype: + return _aten_index_bool(self, indices) index_ranks = [len(index.shape) for index in indices if index is not None] return _aten_index_onnx(self, indices, index_ranks) -@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True) -def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements +def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements index_ranks = [len(index.shape) for index in indices if index is not None] if index_ranks[0] == 1: @@ -4146,7 +4151,7 @@ def aten_index_copy( @torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True) def aten_index_put( self: TReal, - indices: Sequence[INT64], + indices: Sequence[[Union[INT64, BOOL]]], values: TReal, accumulate: bool = False, ) -> TReal: @@ -4155,6 +4160,10 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ + # Handle Boolean indexing first + for index in indices: + if index is not None and index.dtype == BOOL.dtype: + return _aten_index_put_bool(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): # Remove ones until the rank of reshape_list matches values_shape. @@ -4232,8 +4241,7 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape): return result -@torch_op("aten::index_put", trace_only=True) -def aten_index_put_bool( +def _aten_index_put_bool( self: TReal, indices: Sequence[BOOL], values: TReal, diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b60fd8cf31..4a36da2d67 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -721,23 +721,10 @@ def _where_input_wrangler( # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), - TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool), - TorchLibOpInfo( - "index_put_bool", - core_ops.aten_index_put_bool, - input_wrangler=_index_put_input_wrangler, - ).skip( - matcher=lambda sample: sample.args[0][0].dtype != torch.bool, - reason="this Aten overload only supports tensor(bool) as indices", - ), + TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index), TorchLibOpInfo( "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler - ) - .skip( - matcher=lambda sample: sample.args[0][0].dtype != torch.int64, - reason="this Aten overload only supports tensor(int) as indices", - ) - .xfail( + ).xfail( dtypes=(torch.float16,), matcher=lambda sample: sample.kwargs.get("accumulate") is True, reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'", @@ -1806,7 +1793,6 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) -ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) From 477d9fbb18335620502fefb48a5a641632e59211 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 12:36:30 -0700 Subject: [PATCH 2/7] Remove registration of private functions Signed-off-by: Justin Chu --- .../torch_lib/deduce_type_constraints_test.py | 2 +- .../function_libs/torch_lib/ops/core.py | 2 -- onnxscript/function_libs/torch_lib/ops/nn.py | 1 - .../function_libs/torch_lib/registration.py | 23 ++++++++++--------- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py index a8d15c242a..a2db474acc 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py @@ -21,7 +21,7 @@ def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunction, None, None]: for op in registration.default_registry.values(): - for func in (*op.overloads, *op.privates, *op.complex): + for func in (*op.overloads, *op.complex): if isinstance(func, onnxscript.OnnxFunction): yield func diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 162672696e..4612c01c27 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3704,7 +3704,6 @@ def aten_grid_sampler( padding_mode_options = ("zeros", "border", "reflection") padding_mode_str = padding_mode_options[padding_mode] - # Only one onnx Op so don't put into private function return op.GridSample( input, grid, @@ -3730,7 +3729,6 @@ def aten_grid_sampler_2d( padding_mode_options = ("zeros", "border", "reflection") padding_mode_str = padding_mode_options[padding_mode] - # Only one onnx Op so don't put into private function return op.GridSample( input, grid, diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 2a7a46ec28..ab733a7b46 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -330,7 +330,6 @@ def aten_col2im( else: # assert len(padding) == 4, already [w, x, y, z] pads = padding - # Only one ONNX op here so didn't write a private function return op.Col2Im( self, output_size, diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index 162d69d747..c7c0a39634 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -22,14 +22,12 @@ class OverloadedFunction: Attributes: name: Name of the op. E.g. "aten::add". overloads: Overloads function. - privates: Private functions not exposed to users. complex: Support complex functions. """ def __init__(self, name: str): self.name = name self.overloads: list[Any] = [] - self.privates: list[Any] = [] self.complex: list[Any] = [] @@ -39,17 +37,18 @@ class Registry: def __init__(self): self._registry: dict[str, OverloadedFunction] = {} - def register( - self, func: Any, name: str, *, private: bool = False, complex: bool = False - ) -> None: + def register(self, func: Any, name: str, *, complex: bool = False) -> None: """Register a function.""" + overloaded_function = self._registry.setdefault(name, OverloadedFunction(name)) - if private: - self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func) - elif complex: - self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func) + if complex: + if overloaded_function.complex: + raise ValueError(f"Complex overload for '{name}' already registered.") + overloaded_function.complex.append(func) else: - self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func) + if overloaded_function.overloads: + raise ValueError(f"Real overload for '{name}' already registered.") + overloaded_function.overloads.append(func) def __getitem__(self, name): return self._registry[name] @@ -131,7 +130,9 @@ def wrapper( assert registry is not None for name_ in _check_and_normalize_names(name): - registry.register(processed_func, name_, private=private, complex=complex) + if private: + continue + registry.register(processed_func, name_, complex=complex) return processed_func return wrapper From f704a3a0e5b24d041bb464a205c18de8847127cd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 12:41:19 -0700 Subject: [PATCH 3/7] fix typing Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4612c01c27..d37e149524 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4082,7 +4082,9 @@ def aten_index( """ # Handle Boolean indexing first for index in indices: - if index is not None and index.dtype == BOOL.dtype: + if index is None: + continue + if index.dtype == BOOL.dtype: return _aten_index_bool(self, indices) index_ranks = [len(index.shape) for index in indices if index is not None] @@ -4149,7 +4151,7 @@ def aten_index_copy( @torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True) def aten_index_put( self: TReal, - indices: Sequence[[Union[INT64, BOOL]]], + indices: Sequence[Optional[Union[INT64, BOOL]]], values: TReal, accumulate: bool = False, ) -> TReal: @@ -4160,7 +4162,9 @@ def aten_index_put( """ # Handle Boolean indexing first for index in indices: - if index is not None and index.dtype == BOOL.dtype: + if index is None: + continue + if index.dtype == BOOL.dtype: return _aten_index_put_bool(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): From 8f69c8daead1d0e7bdbc84879d529eb84ebf2f75 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 12:43:58 -0700 Subject: [PATCH 4/7] test Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4a36da2d67..f68a9f7d34 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -724,7 +724,7 @@ def _where_input_wrangler( TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index), TorchLibOpInfo( "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler - ).xfail( + ).skip( dtypes=(torch.float16,), matcher=lambda sample: sample.kwargs.get("accumulate") is True, reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'", From 15d3c45db6bfcf697ac64ff4cb6ff9647a5b9c64 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 12:45:43 -0700 Subject: [PATCH 5/7] msg Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/registration.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index c7c0a39634..e4000bc55e 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -43,11 +43,15 @@ def register(self, func: Any, name: str, *, complex: bool = False) -> None: if complex: if overloaded_function.complex: - raise ValueError(f"Complex overload for '{name}' already registered.") + raise ValueError( + f"Complex overload for '{name}' already registered: {overloaded_function.complex}." + ) overloaded_function.complex.append(func) else: if overloaded_function.overloads: - raise ValueError(f"Real overload for '{name}' already registered.") + raise ValueError( + f"Real overload for '{name}' already registered: {overloaded_function.overloads}." + ) overloaded_function.overloads.append(func) def __getitem__(self, name): From 8515502e94e554b0221ba55daf9dad638624fa39 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 22:15:23 -0700 Subject: [PATCH 6/7] skip pytest Signed-off-by: Justin Chu --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4f7edc9bf8..315cc8a6ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,11 @@ onnx = ["py.typed"] [tool.pytest.ini_options] addopts = "-rsfEX --tb=short --color=yes" +norecursedirs = [ + # Skip test collection because pytest will try to import the modules twice, + # causing the torchlib registry to complain that functions are redefined. + "onnxscript/function_libs/torch_lib/ops", +] [tool.mypy] # TODO disallow_incomplete_defs = true From 90c3d0227f45cf5d7b9e5d1ea282187f8fa29945 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 22:19:37 -0700 Subject: [PATCH 7/7] index_bool is wrong Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 9ae8eb1c43..539947bbbd 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4201,9 +4201,9 @@ def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Ten finla_rank = input_rank - (len(index.shape) - 1) trans_perm = list(range(finla_rank)) trans_perm = trans_perm[-1:] + trans_perm[:-1] - for _ in range(count_of_none): - result = op.Transpose(result, perm=trans_perm) - return result + for _ in range(count_of_none): + result = op.Transpose(result, perm=trans_perm) + return result def aten_index_add(