Commit 32a61f4

[torchlib] Deprecate Rank and IsScalar (#2624)
Deprecate Rank and IsScalar and remove all usages. Do not remove the definitions because older versions of PyTorch assume their existence.

---------

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 071ff1e commit 32a61f4
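The commit applies one substitution pattern throughout: functions marked trace_only read the rank directly from the traced value's static shape via len(x.shape), and the few places that still need the rank as a graph value inline op.Size(op.Shape(x)) instead of calling the shared Rank helper. A minimal sketch of the pattern, assuming a hypothetical trace-only function and input value x (not part of this commit):

    # Hypothetical trace-only torchlib function, for illustration only.
    def example_op(x):
        rank = len(x.shape)  # replaces: rank = Rank(x)
        if rank == 0:        # replaces: IsScalar(x)
            x = op.Unsqueeze(x, [0])
        # Where the rank must stay a graph value, the two ops are inlined:
        # rank_value = op.Size(op.Shape(x))
        return op.Identity(x)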

File tree

4 files changed: +39 -68 lines changed

onnxscript/function_libs/torch_lib/ops/common.py

Lines changed: 12 additions & 2 deletions
@@ -28,14 +28,24 @@
 
 @onnxscript.script(common_opset)
 def Rank(input: tensor_typing.TTensor) -> INT64:
-    """Take the rank of the input tensor."""
+    """Deprecated.
+
+    NOTE: Do not remove, for backward compatibility with PyTorch < 2.10.
+
+    Take the rank of the input tensor.
+    """
 
     return op.Size(op.Shape(input))
 
 
 @onnxscript.script(common_opset)
 def IsScalar(input: tensor_typing.TTensor) -> BOOL:
-    """Return whether the input has rank 0, or is a scalar."""
+    """Deprecated.
+
+    NOTE: Do not remove, for backward compatibility with PyTorch < 2.10.
+
+    Return whether the input has rank 0, or is a scalar.
+    """
 
     return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 22 additions & 43 deletions
@@ -54,7 +54,6 @@
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
 _MATH_PI = math.pi
-Rank = common_ops.Rank
 
 
 @torch_op("aten::_local_scalar_dense", trace_only=True)
@@ -947,11 +946,11 @@ def reshape_to_1d(tensor):
     return op.SequenceMap(self, body=reshape_to_1d)
 
 
-@torch_op("aten::atleast_2d")
+@torch_op("aten::atleast_2d", trace_only=True)
 def aten_atleast_2d(self: TTensor) -> TTensor:
     """atleast_2d(Tensor self) -> Tensor"""
 
-    if Rank(self) <= 1:
+    if len(self.shape) <= 1:
         self = op.Reshape(self, op.Constant(value_ints=[1, -1]))
     return op.Identity(self)
@@ -975,7 +974,7 @@ def reshape_to_2d(tensor):
 def aten_atleast_3d(self: TTensor) -> TTensor:
     """atleast_3d(Tensor self) -> Tensor"""
 
-    rank = Rank(self)
+    rank = len(self.shape)
     if rank <= 1:
         self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1]))
     elif rank == 2:
@@ -1820,39 +1819,21 @@ def aten_conj_physical(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::constant_pad_nd")
-def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTensor:
+@torch_op("aten::constant_pad_nd", trace_only=True)
+def aten_constant_pad_nd(self: TTensor, pad: Sequence[INT64], value: float = 0.0) -> TTensor:
     """constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor"""
 
     # The desired order of paddings is
     # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
     # n is the dimension of input.
     # assume zero-dimensions in the beginning
-    # rank = len(self.shape) # rank must be scalar
-    # paddings = list(pad[:]) + [0] * (rank * 2 - len(pad))
+    rank = len(self.shape)
+    paddings = list(pad) + [0] * (rank * 2 - len(pad))
     # reverse order and collate first beginnings and then ends
-    # paddings = paddings[-2::-2] + paddings[-1::-2]
-
-    neg_1 = op.Constant(value_ints=[-1])
-
-    zero_count = op.Sub(op.Mul(Rank(self), 2), op.Size(pad))
-    zero_count = op.Reshape(zero_count, neg_1)
-    zero = op.Constant(value_ints=[0])
-    zeros = op.Expand(zero, zero_count)
-    torch_paddings = op.Concat(pad, zeros, axis=0)
-    size_d = op.Size(torch_paddings)
-    steps = op.Constant(value_ints=[-2])
-
-    starts = steps
-    ends = op.Sub(starts, size_d)
-    odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
-
-    starts = neg_1
-    ends = op.Sub(starts, size_d)
-    even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
+    paddings = paddings[-2::-2] + paddings[-1::-2]
+    constant_value = op.Constant(value=ir.tensor(value, dtype=self.dtype))
 
-    onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
-    return op.Pad(self, onnx_padding, value)
+    return op.Pad(self, paddings, constant_value)
 
 
 @torch_op("aten::contiguous", trace_only=True)
@@ -3996,7 +3977,7 @@ def reshape_to_atleast_2d(tensor):
     result = op.ConcatFromSequence(tensors_atleast_2d, axis=1, new_axis=0)
 
     # hstack expects a non-empty sequence of tensors. So we don't need to check for length
-    rank_1d_or_less = op.Less(Rank(op.SequenceAt(tensors, 0)), 2)
+    rank_1d_or_less = op.Less(op.Size(op.Shape(op.SequenceAt(tensors, 0))), 2)
     if rank_1d_or_less:
         result = op.Reshape(result, op.Constant(value_ints=[-1]))
     return result
@@ -6076,7 +6057,7 @@ def aten_native_group_norm(
     norm = op.Reshape(norm, op.Shape(input), allowzero=True)
     # Using the input weight and bias to do affine
    # But need to unsqueeze to the target shape for broading cast easy
-    input_rank = Rank(input)
+    input_rank = len(input.shape)
     axes_unsqueeze = op.Range(1, input_rank - 1, 1)
     weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
     bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
@@ -8229,7 +8210,7 @@ def aten_symeig(
 def aten_t(self: TTensor) -> TTensor:
     """t(Tensor(a) self) -> Tensor(a)"""
 
-    rank = Rank(self)
+    rank = len(self.shape)
     if rank == 2:
         result = op.Transpose(self, perm=[1, 0])
     else:
@@ -8312,26 +8293,24 @@ def aten_threshold_backward(
     raise NotImplementedError()
 
 
-@torch_op("aten::tile")
-def aten_tile(self: TTensor, dims: INT64) -> TTensor:
+@torch_op("aten::tile", trace_only=True)
+def aten_tile(self: TTensor, dims: Sequence[int]) -> TTensor:
     """tile(Tensor self, int[] dims) -> Tensor"""
 
-    self_rank = Rank(self)
-    dims_rank = op.Size(dims)
-    diff = op.Sub(self_rank, dims_rank)
+    self_rank = len(self.shape)
+    dims_rank = len(dims)
+    diff = self_rank - dims_rank
 
     if diff > 0:
         # dims is shorter than self.shape
         # pad dims with 1
-        diff_1d = op.Reshape(diff, op.Constant(value_ints=[1]))
-        exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
-        dims = op.Concat(exapnd_ones, dims, axis=0)
+        exapnd_ones = [1] * diff
+        dims = [*exapnd_ones, *dims]
 
-    if diff < 0:
+    elif diff < 0:
         # dims is longer than self.shape
         # pad self.shape with 1
-        diff_1d = op.Reshape(op.Abs(diff), op.Constant(value_ints=[1]))
-        exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
+        exapnd_ones = op.Constant(value_ints=[1] * (-diff))
         self_shape = op.Shape(self)
         self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0)
         self = op.Reshape(self, self_final_shape, allowzero=True)
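The new aten_tile above pads dims with leading 1s in plain Python when dims has fewer entries than the input rank, rather than building that list with ONNX ops. A small illustration with hypothetical values (not from the commit):

    # A rank-3 input tiled with dims=[2]: dims is padded to [1, 1, 2],
    # so only the last axis is repeated.
    self_rank = 3
    dims = [2]
    diff = self_rank - len(dims)  # 2
    if diff > 0:
        dims = [*([1] * diff), *dims]  # [1, 1, 2]
    # The padded dims then feed the ONNX Tile op later in the function
    # (not shown in this hunk).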

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 5 additions & 7 deletions
@@ -18,7 +18,6 @@
 from typing import Optional, Sequence, Tuple, TypeVar, Union
 
 from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir
-from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import (
     IntType,
@@ -32,7 +31,6 @@
 from onnxscript.onnx_types import TensorType
 
 _MATH_PI = math.pi
-Rank = common_ops.Rank
 
 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
@@ -576,7 +574,7 @@ def aten_group_norm(
     norm = op.Reshape(norm, op.Shape(input))
     # Using the input weight and bias to do affine
     # But need to unsqueeze to the target shape for broading cast easy
-    input_rank = Rank(input)
+    input_rank = len(input.shape)
     one = op.Constant(value_int=1)
     axes_unsqueeze = op.Range(one, op.Sub(input_rank, one), one)
     weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
@@ -999,7 +997,7 @@ def _aten_max_pool_onnx(
     ceil_mode: bool,
     unbatched_rank: int,
 ) -> TFloatOrUInt8:
-    self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
+    self_rank_is_unbatched_rank = len(self.shape) == unbatched_rank
     if self_rank_is_unbatched_rank:  # C,H,W -> N,C,H,W and N=1
         self = op.Unsqueeze(self, [0])
 
@@ -1133,7 +1131,7 @@ def _aten_max_pool_with_indices_onnx(
     n_dims_zero: Sequence[int],
     n_dims_axes: Sequence[int],
 ) -> Tuple[TFloatOrUInt8, INT64]:
-    self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
+    self_rank_is_unbatched_rank = len(self.shape) == unbatched_rank
     if self_rank_is_unbatched_rank:
         self = op.Unsqueeze(self, axes=[0])
 
@@ -1362,11 +1360,11 @@ def aten_nll_loss(
 ) -> TFloat:
     """nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor"""
 
-    self_rank_is_1 = Rank(self) == 1
+    self_rank_is_1 = len(self.shape) == 1
     if self_rank_is_1:  # self rank should be at least 2
         self = op.Unsqueeze(self, [0])
 
-    rank_target = Rank(target)
+    rank_target = len(target.shape)
     if rank_target == 0:  # target rank should be at least 1
         target = op.Unsqueeze(target, [0])

tests/function_libs/torch_lib/ops_test_common.py

Lines changed: 0 additions & 16 deletions
@@ -26,7 +26,6 @@
 
 import numpy as np
 import onnx
-import onnx_ir.passes.common as common_passes
 import onnxruntime as ort
 import onnxruntime.capi.onnxruntime_pybind11_state
 import pytest
@@ -37,7 +36,6 @@
 import onnxscript
 import onnxscript.evaluator
 from onnxscript import ir
-from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from tests.function_libs.torch_lib import error_reproduction
 
 T = TypeVar("T")
@@ -412,19 +410,6 @@ def _format_model_and_input_information(onnx_model, inputs):
     }
 
 
-def add_torchlib_common_imports(model: ir.Model) -> None:
-    """Hack to add torchlib common imports to the model."""
-
-    model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
-    rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
-    is_scalar_func = ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto())
-    model.functions[rank_func.identifier()] = rank_func
-    model.functions[is_scalar_func.identifier()] = is_scalar_func
-    removal_pass = common_passes.RemoveUnusedFunctionsPass()
-    assert removal_pass.in_place
-    removal_pass(model)
-
-
 def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
     """Checks if the dtype is compatible with the schema.
@@ -593,7 +578,6 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
     proto = onnxscript_function.to_function_proto()
     ir_function = ir.serde.deserialize_function(proto)
     onnx_model.functions[identifier] = ir_function
-    add_torchlib_common_imports(onnx_model)
     # Make sure the model is valid
     model_proto = ir.to_proto(onnx_model)
     try:
