diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
index b547737bf5..2168eda99e 100644
--- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
+++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
@@ -21,7 +21,7 @@ def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunction, None, None]:
     for op in registration.default_registry.values():
-        for func in (*op.overloads, *op.privates, *op.complex):
+        for func in (*op.overloads, *op.complex):
             if isinstance(func, onnxscript.OnnxFunction):
                 yield func
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index be30520878..34e855597b 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3782,7 +3782,6 @@ def aten_grid_sampler(
     padding_mode_options = ("zeros", "border", "reflection")
     padding_mode_str = padding_mode_options[padding_mode]

-    # Only one onnx Op so don't put into private function
     return op.GridSample(
         input,
         grid,
@@ -3808,7 +3807,6 @@ def aten_grid_sampler_2d(
     padding_mode_options = ("zeros", "border", "reflection")
     padding_mode_str = padding_mode_options[padding_mode]

-    # Only one onnx Op so don't put into private function
     return op.GridSample(
         input,
         grid,
@@ -4138,7 +4136,9 @@ def _aten_index_onnx(


 @torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
-def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
+def aten_index(
+    self: TensorType, indices: Sequence[Optional[Union[INT64, BOOL]]]
+) -> TensorType:
     """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor

     NOTE: Understanding `aten::index`
@@ -4158,14 +4158,19 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy
     None in `indices` are like fillers for dimensions that cannot be removed in the process.
     """
+    # Handle Boolean indexing first
+    for index in indices:
+        if index is None:
+            continue
+        if index.dtype == BOOL.dtype:
+            return _aten_index_bool(self, indices)
     index_ranks = [len(index.shape) for index in indices if index is not None]

     return _aten_index_onnx(self, indices, index_ranks)


-@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
-def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:  # pylint: disable=inconsistent-return-statements
+def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:  # pylint: disable=inconsistent-return-statements
     index_ranks = [len(index.shape) for index in indices if index is not None]

     if index_ranks[0] == 1:
@@ -4200,9 +4205,9 @@ def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Tens
         finla_rank = input_rank - (len(index.shape) - 1)
         trans_perm = list(range(finla_rank))
         trans_perm = trans_perm[-1:] + trans_perm[:-1]
-        for _ in range(count_of_none):
-            result = op.Transpose(result, perm=trans_perm)
-        return result
+    for _ in range(count_of_none):
+        result = op.Transpose(result, perm=trans_perm)
+    return result


 def aten_index_add(
@@ -4224,7 +4229,7 @@ def aten_index_copy(
 @torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
 def aten_index_put(
     self: TReal,
-    indices: Sequence[INT64],
+    indices: Sequence[Optional[Union[INT64, BOOL]]],
     values: TReal,
     accumulate: bool = False,
 ) -> TReal:
@@ -4233,6 +4238,12 @@ def aten_index_put(
     See implementation of `torch.onnx.symbolic_opset11.index_put `_.
     """
+    # Handle Boolean indexing first
+    for index in indices:
+        if index is None:
+            continue
+        if index.dtype == BOOL.dtype:
+            return _aten_index_put_bool(self, indices, values, accumulate=accumulate)

     def _make_reshape_list_broadcastable(reshape_list, values_shape):
         # Remove ones until the rank of reshape_list matches values_shape.
@@ -4310,8 +4321,7 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
     return result


-@torch_op("aten::index_put", trace_only=True)
-def aten_index_put_bool(
+def _aten_index_put_bool(
     self: TReal,
     indices: Sequence[BOOL],
     values: TReal,
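Note on the core.py hunks above: instead of registering separate `aten_index_bool` / `aten_index_put_bool` overloads for `aten::index.Tensor` and `aten::index_put`, the public functions now accept `Optional[Union[INT64, BOOL]]` indices and dispatch to private helpers whenever any index is a boolean mask. A minimal runnable sketch of that dispatch pattern, using numpy arrays as stand-ins for ONNX tensors (`index`, `_index_bool`, and `_index_int` are hypothetical names for this sketch, not torchlib APIs):

```python
from typing import Optional, Sequence

import numpy as np


def index(self: np.ndarray, indices: Sequence[Optional[np.ndarray]]) -> np.ndarray:
    # Route boolean masks to the bool-specialized path first,
    # mirroring the dtype check the diff adds to aten_index.
    for idx in indices:
        if idx is None:
            continue
        if idx.dtype == np.bool_:
            return _index_bool(self, indices)
    return _index_int(self, indices)


def _index_bool(self, indices):
    # Simplified: assumes exactly one non-None boolean mask.
    (mask,) = [i for i in indices if i is not None]
    return self[mask]


def _index_int(self, indices):
    # Simplified: assumes exactly one non-None integer index on dim 0.
    (idx,) = [i for i in indices if i is not None]
    return np.take(self, idx, axis=0)


x = np.arange(6).reshape(2, 3)
print(index(x, [np.array([True, False])]))  # bool path: selects row 0
print(index(x, [np.array([1])]))            # int path: selects row 1
```

Making the bool variants private also means only the dispatching function carries the `@torch_op` registration, which is what allows the registry change below to enforce a single real overload per op name.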
""" + # Handle Boolean indexing first + for index in indices: + if index is None: + continue + if index.dtype == BOOL.dtype: + return _aten_index_bool(self, indices) index_ranks = [len(index.shape) for index in indices if index is not None] return _aten_index_onnx(self, indices, index_ranks) -@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True) -def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements +def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements index_ranks = [len(index.shape) for index in indices if index is not None] if index_ranks[0] == 1: @@ -4200,9 +4205,9 @@ def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Tens finla_rank = input_rank - (len(index.shape) - 1) trans_perm = list(range(finla_rank)) trans_perm = trans_perm[-1:] + trans_perm[:-1] - for _ in range(count_of_none): - result = op.Transpose(result, perm=trans_perm) - return result + for _ in range(count_of_none): + result = op.Transpose(result, perm=trans_perm) + return result def aten_index_add( @@ -4224,7 +4229,7 @@ def aten_index_copy( @torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True) def aten_index_put( self: TReal, - indices: Sequence[INT64], + indices: Sequence[Optional[Union[INT64, BOOL]]], values: TReal, accumulate: bool = False, ) -> TReal: @@ -4233,6 +4238,12 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ + # Handle Boolean indexing first + for index in indices: + if index is None: + continue + if index.dtype == BOOL.dtype: + return _aten_index_put_bool(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): # Remove ones until the rank of reshape_list matches values_shape. @@ -4310,8 +4321,7 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape): return result -@torch_op("aten::index_put", trace_only=True) -def aten_index_put_bool( +def _aten_index_put_bool( self: TReal, indices: Sequence[BOOL], values: TReal, diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 5edcc233d0..f58b2c7f18 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -328,7 +328,6 @@ def aten_col2im( else: # assert len(padding) == 4, already [w, x, y, z] pads = padding - # Only one ONNX op here so didn't write a private function return op.Col2Im( self, output_size, diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index 162d69d747..e4000bc55e 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -22,14 +22,12 @@ class OverloadedFunction: Attributes: name: Name of the op. E.g. "aten::add". overloads: Overloads function. - privates: Private functions not exposed to users. complex: Support complex functions. 
""" def __init__(self, name: str): self.name = name self.overloads: list[Any] = [] - self.privates: list[Any] = [] self.complex: list[Any] = [] @@ -39,17 +37,22 @@ class Registry: def __init__(self): self._registry: dict[str, OverloadedFunction] = {} - def register( - self, func: Any, name: str, *, private: bool = False, complex: bool = False - ) -> None: + def register(self, func: Any, name: str, *, complex: bool = False) -> None: """Register a function.""" - - if private: - self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func) - elif complex: - self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func) + overloaded_function = self._registry.setdefault(name, OverloadedFunction(name)) + + if complex: + if overloaded_function.complex: + raise ValueError( + f"Complex overload for '{name}' already registered: {overloaded_function.complex}." + ) + overloaded_function.complex.append(func) else: - self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func) + if overloaded_function.overloads: + raise ValueError( + f"Real overload for '{name}' already registered: {overloaded_function.overloads}." + ) + overloaded_function.overloads.append(func) def __getitem__(self, name): return self._registry[name] @@ -131,7 +134,9 @@ def wrapper( assert registry is not None for name_ in _check_and_normalize_names(name): - registry.register(processed_func, name_, private=private, complex=complex) + if private: + continue + registry.register(processed_func, name_, complex=complex) return processed_func return wrapper diff --git a/pyproject.toml b/pyproject.toml index 4f7edc9bf8..315cc8a6ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,11 @@ onnx = ["py.typed"] [tool.pytest.ini_options] addopts = "-rsfEX --tb=short --color=yes" +norecursedirs = [ + # Skip test collection because pytest will try to import the modules twice, + # causing the torchlib registry to complain that functions are redefined. 
+ "onnxscript/function_libs/torch_lib/ops", +] [tool.mypy] # TODO disallow_incomplete_defs = true diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b60fd8cf31..f68a9f7d34 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -721,23 +721,10 @@ def _where_input_wrangler( # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), - TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool), - TorchLibOpInfo( - "index_put_bool", - core_ops.aten_index_put_bool, - input_wrangler=_index_put_input_wrangler, - ).skip( - matcher=lambda sample: sample.args[0][0].dtype != torch.bool, - reason="this Aten overload only supports tensor(bool) as indices", - ), + TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index), TorchLibOpInfo( "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler - ) - .skip( - matcher=lambda sample: sample.args[0][0].dtype != torch.int64, - reason="this Aten overload only supports tensor(int) as indices", - ) - .xfail( + ).skip( dtypes=(torch.float16,), matcher=lambda sample: sample.kwargs.get("accumulate") is True, reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'", @@ -1806,7 +1793,6 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) -ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",))