Commit a70ee8d

Use ir.val to replace ir.Input (#2556)
Use ir.val to replace ir.Input because ir.Input was deprecated.

---------

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 92633a6 commit a70ee8d
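
The migration pattern is the same in every file touched below: wherever a graph input or an explicitly typed op output was built with the deprecated ir.Input constructor, it is now built with ir.val. A minimal sketch based on the call sites updated in this commit (the exact ir.val signature in onnx_ir 0.1.9 may accept more parameters than shown here):

    from onnxscript import ir  # re-exports onnx_ir; see onnxscript/ir/__init__.py below

    # Deprecated form removed by this commit:
    #   x = ir.Input("X", ir.Shape([2, 3, 4]), ir.TensorType(ir.DataType.FLOAT))

    # Replacement: ir.val builds the equivalent ir.Value. Name, dtype, and shape
    # can be passed positionally or as keywords, and type= is still accepted.
    x = ir.val("X", ir.DataType.FLOAT, ir.Shape([2, 3, 4]))
    y = ir.val("Y", ir.DataType.FLOAT)                 # dtype only; shape left unset
    v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4]))  # dtype can be assigned afterwards
    v0.dtype = ir.DataType.BFLOAT16

    graph = ir.Graph([x], [y], nodes=[], opset_imports={"": 20})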

File tree: 7 files changed, +20 -171 lines

noxfile.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@
     "packaging",
     "protobuf",
 )
-ONNX_IR = "onnx_ir==0.1.7"
+ONNX_IR = "onnx_ir==0.1.9"
 ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"

onnxscript/ir/__init__.py

Lines changed: 2 additions & 152 deletions

@@ -1,154 +1,4 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-"""In-memory intermediate representation for ONNX graphs."""
-
-__all__ = [
-    # Modules
-    "serde",
-    "traversal",
-    "convenience",
-    "external_data",
-    "tape",
-    # IR classes
-    "Tensor",
-    "ExternalTensor",
-    "StringTensor",
-    "LazyTensor",
-    "SymbolicDim",
-    "Shape",
-    "TensorType",
-    "OptionalType",
-    "SequenceType",
-    "SparseTensorType",
-    "TypeAndShape",
-    "Value",
-    "Attr",
-    "RefAttr",
-    "Node",
-    "Function",
-    "Graph",
-    "GraphView",
-    "Model",
-    # Constructors
-    "AttrFloat32",
-    "AttrFloat32s",
-    "AttrGraph",
-    "AttrGraphs",
-    "AttrInt64",
-    "AttrInt64s",
-    "AttrSparseTensor",
-    "AttrSparseTensors",
-    "AttrString",
-    "AttrStrings",
-    "AttrTensor",
-    "AttrTensors",
-    "AttrTypeProto",
-    "AttrTypeProtos",
-    "Input",
-    # Protocols
-    "ArrayCompatible",
-    "DLPackCompatible",
-    "TensorProtocol",
-    "ValueProtocol",
-    "ModelProtocol",
-    "NodeProtocol",
-    "GraphProtocol",
-    "GraphViewProtocol",
-    "AttributeProtocol",
-    "ReferenceAttributeProtocol",
-    "SparseTensorProtocol",
-    "SymbolicDimProtocol",
-    "ShapeProtocol",
-    "TypeProtocol",
-    "MapTypeProtocol",
-    "FunctionProtocol",
-    # Enums
-    "AttributeType",
-    "DataType",
-    # Types
-    "OperatorIdentifier",
-    # Protobuf compatible types
-    "TensorProtoTensor",
-    # Conversion functions
-    "from_proto",
-    "from_onnx_text",
-    "to_proto",
-    # Convenience constructors
-    "tensor",
-    "node",
-    # Pass infrastructure
-    "passes",
-    # IO
-    "load",
-    "save",
-]
-
-from onnx_ir import (
-    ArrayCompatible,
-    Attr,
-    AttrFloat32,
-    AttrFloat32s,
-    AttrGraph,
-    AttrGraphs,
-    AttributeProtocol,
-    AttributeType,
-    AttrInt64,
-    AttrInt64s,
-    AttrSparseTensor,
-    AttrSparseTensors,
-    AttrString,
-    AttrStrings,
-    AttrTensor,
-    AttrTensors,
-    AttrTypeProto,
-    AttrTypeProtos,
-    DataType,
-    DLPackCompatible,
-    ExternalTensor,
-    Function,
-    FunctionProtocol,
-    Graph,
-    GraphProtocol,
-    GraphView,
-    GraphViewProtocol,
-    Input,
-    LazyTensor,
-    MapTypeProtocol,
-    Model,
-    ModelProtocol,
-    Node,
-    NodeProtocol,
-    OperatorIdentifier,
-    OptionalType,
-    RefAttr,
-    ReferenceAttributeProtocol,
-    SequenceType,
-    Shape,
-    ShapeProtocol,
-    SparseTensorProtocol,
-    SparseTensorType,
-    StringTensor,
-    SymbolicDim,
-    SymbolicDimProtocol,
-    Tensor,
-    TensorProtocol,
-    TensorProtoTensor,
-    TensorType,
-    TypeAndShape,
-    TypeProtocol,
-    Value,
-    ValueProtocol,
-    convenience,
-    external_data,
-    from_onnx_text,
-    from_proto,
-    load,
-    node,
-    passes,
-    save,
-    serde,
-    tape,
-    tensor,
-    to_proto,
-    traversal,
-)
+# pylint: disable=wildcard-import,unused-wildcard-import
+from onnx_ir import *  # type: ignore  # noqa: F403
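
With the wildcard re-export, onnxscript.ir becomes a thin alias over the standalone onnx_ir package, so existing imports keep resolving to the same objects and new helpers such as ir.val are picked up automatically. A quick sanity check (a sketch, assuming onnx_ir>=0.1.9 is installed per the pins below):

    import onnx_ir
    from onnxscript import ir

    assert ir.Value is onnx_ir.Value      # re-exported names are the same objects
    v = ir.val("X", ir.DataType.FLOAT)    # ir.val is reachable through onnxscript.ir
    assert isinstance(v, ir.Value)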

onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py

Lines changed: 3 additions & 3 deletions

@@ -14,11 +14,11 @@

 class Bfloat16ConversionTest(unittest.TestCase):
     def setUp(self) -> None:
-        self.v0 = ir.Input(name="v0", shape=ir.Shape([2, 3, 4]))
+        self.v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4]))
         self.v0.dtype = ir.DataType.BFLOAT16
-        self.v1 = ir.Input(name="v1", shape=ir.Shape([2, 3, 4]))
+        self.v1 = ir.val(name="v1", shape=ir.Shape([2, 3, 4]))
         self.v1.dtype = ir.DataType.BFLOAT16
-        self.v2 = ir.Input(name="v2", shape=ir.Shape([2, 3, 4]))
+        self.v2 = ir.val(name="v2", shape=ir.Shape([2, 3, 4]))
         self.v2.dtype = ir.DataType.BFLOAT16

         self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)

onnxscript/rewriter/rules/common/_basic_rules_test.py

Lines changed: 5 additions & 5 deletions

@@ -421,14 +421,14 @@ def _convert_shape(shape, name):
             if isinstance(shape, np.ndarray):
                 shape = tape.initializer(ir.Tensor(shape, name=name))
             elif isinstance(shape, (list, tuple)):
-                shape = ir.Input(name, ir.Shape(shape), ir.TensorType(ir.DataType.INT64))
+                shape = ir.val(name, ir.DataType.INT64, ir.Shape(shape))
                 tape.graph_like.inputs.append(shape)
             else:
                 raise TypeError(f"Unsupported type {type(shape)} for shape.")
             return shape

-        x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
-        y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
+        y = ir.val("Y", ir.DataType.FLOAT)
         tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))

         # Build the graph.

@@ -554,8 +554,8 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg):
 class Flatten2ReshapeTest(unittest.TestCase):
     @staticmethod
     def create_model(input_shape, axis=1):
-        x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
-        y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
+        y = ir.val("Y", ir.DataType.FLOAT)
         tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))

         # Build the graph.

onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py

Lines changed: 4 additions & 4 deletions

@@ -61,13 +61,13 @@ def build_model(

         # Register operations in the tape
         idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT
-        x = ir.Input("X", shape=input_shape, type=ir.TensorType(idtype))
+        x = ir.val("X", shape=input_shape, type=ir.TensorType(idtype))
         y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes)
         y = tape.op(
             op_type,
             inputs=[y, self.get_conv_weights(weight_shape, tape)],
             attributes=conv_attributes,
-            output=ir.Input("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
+            output=ir.val("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
         )
         if op_type == "ConvInteger":
             y.dtype = ir.DataType.INT32

@@ -290,12 +290,12 @@ def build_model(
             raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")

         # Register operations in the tape
-        x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
         y = tape.op(
             "Conv",
             inputs=[x, *conv_inputs],
             attributes=conv_attributes,
-            output=ir.Input("Y", shape=output_shape, type=x.type),
+            output=ir.val("Y", shape=output_shape, type=x.type),
         )

         # Build the model

onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py

Lines changed: 4 additions & 4 deletions

@@ -46,10 +46,10 @@ def get_test_model(
         bias_shape = weight_shape[0] if transB else weight_shape[-1]
         output_shape = ir.Shape(("?",) * input_shape.rank())

-        x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))

         if weight_as_inputs:
-            w = ir.Input("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
+            w = ir.val("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
             inputs.append(w)
         else:
             w = ir.tensor(

@@ -58,7 +58,7 @@ def get_test_model(
             w = tape.initializer(w)

         if bias_as_inputs:
-            b = ir.Input(
+            b = ir.val(
                 "B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT)
             )
             inputs.append(b)

@@ -77,7 +77,7 @@ def get_test_model(
         y = tape.op(
             "Add",
             inputs=[y, b],
-            output=ir.Input("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
+            output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
         )

         # Build the model

pyproject.toml

Lines changed: 1 addition & 2 deletions

@@ -27,7 +27,7 @@ classifiers = [
 dependencies = [
     "ml_dtypes",
     "numpy",
-    "onnx_ir>=0.1.7,<2",  # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range.
+    "onnx_ir>=0.1.9,<2",  # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range.
     "onnx>=1.16",
     "packaging",
     "typing_extensions>=4.10",

@@ -41,7 +41,6 @@ onnxscript = ["py.typed"]
 onnx = ["py.typed"]

 [tool.pytest.ini_options]
-filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"]
 addopts = "-rsfEX --tb=short --color=yes"

 [tool.mypy]
