Commit 9e0366c

Create initializers not constant nodes in constant folding pass (#2650)
Partially from #2598. This yields a better-optimized graph after constant folding in terms of node count, which makes the result easier to debug.
1 parent 45b5189 commit 9e0366c
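For illustration only, here is a minimal sketch of the intended effect, not part of the commit. It assumes that `onnxscript.optimizer.fold_constants` accepts an `onnx_ir` model and folds it in place, and that the ONNX text below parses with `onnx.parser`; the model and the value name `four` are made up to mirror the updated tests.

import onnx
import onnx_ir as ir
from onnxscript import optimizer

# Hypothetical model text: "four" is computable from constants at optimization time.
MODEL_TEXT = """
<ir_version: 7, opset_import: ["" : 17]>
agraph (float[N] x) => (float[N] z)
{
    two = Constant <value_float = 2.0> ()
    four = Add (two, two)
    z = Mul (x, four)
}
"""

model = ir.from_proto(onnx.parser.parse_model(MODEL_TEXT))
optimizer.fold_constants(model)  # assumed to mutate the model in place

# After this change the folded value is stored as a graph initializer rather
# than as a new Constant node, so the node list is shorter.
print([node.op_type for node in model.graph])  # e.g. ["Mul"]
print("four" in model.graph.initializers)      # True

Before this commit the same fold would have inserted a `Constant` node producing `four`; now the value lives in `graph.initializers` and no extra node is created.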

4 files changed: +60 additions, -45 deletions

noxfile.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@
     "packaging",
     "protobuf",
 )
-ONNX_IR = "onnx_ir==0.1.10"
+ONNX_IR = "onnx_ir==0.1.12"
 ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"
onnxscript/optimizer/_constant_folding.py

Lines changed: 35 additions & 24 deletions
@@ -1039,24 +1039,29 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                 e,
             )
 
-    def new_constant(self, node: ir.Node, value) -> ir.Node | None:
-        irvalue = node.outputs[0]
-        if not isinstance(value, np.ndarray):
+    def new_initializer(self, node: ir.Node, array) -> ir.Value | None:
+        original_value = node.outputs[0]
+        if not isinstance(array, np.ndarray):
             # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
             # So, a constant-value of type sequence is not folded, but it can be used
             # to optimize subsequent operations when possible.
             logger.info(
                 "Skip storing constant folded value %s due to unsupported type %s.",
-                irvalue.name,
-                type(value),
+                original_value.name,
+                type(array),
             )
             return None
 
-        tensor = ir.tensor(value)
-        tensor.name = irvalue.name
-        irvalue.const_value = tensor
+        tensor = ir.tensor(array)
+        tensor.name = original_value.name
+        initializer = ir.Value(
+            name=original_value.name,
+            type=ir.TensorType(ir.DataType(tensor.dtype)),
+            shape=tensor.shape,  # type: ignore[arg-type]
+            const_value=tensor,
+        )
 
-        if value.size > self.output_size_limit:
+        if array.size > self.output_size_limit:
             # Handle examples like Transpose(weight) to be folded even if the size is large,
             # as long as weight has no other uses. This won't increase model size.
             removed_input_size = 0
@@ -1065,25 +1070,23 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
                     array = _get_numpy_value(input)
                     if array is not None:
                         removed_input_size += array.size
-            increased_size = value.size - removed_input_size
+            increased_size = array.size - removed_input_size
             if increased_size > 0:
                 logger.info(
                     "Skip storing constant folded nvalue %s due to large size %s.",
-                    irvalue.name,
-                    value.size,
+                    original_value.name,
+                    array.size,
                 )
                 return None
 
         logger.debug(
-            "New constant for value %s dtype: %s shape: %s",
-            irvalue.name,
-            value.dtype,
-            value.shape,
+            "New Initializer for value %s dtype: %s shape: %s",
+            original_value.name,
+            array.dtype,
+            array.shape,
         )
 
-        attributes = ir.convenience.convert_attributes({"value": tensor})
-        node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
-        return node
+        return initializer
 
     def process_node(self, node: ir.Node) -> Replacement | None:
         """Process a node and return a Replacement if the node can be replaced."""
@@ -1109,7 +1112,13 @@ def process_node(self, node: ir.Node) -> Replacement | None:
         self._do_inference(node)
 
         if node.domain not in self._opset_imports:
+            logger.debug(
+                "Skipping constant folding for node %r due to missing opset import for domain %r.",
+                node.name,
+                node.domain,
+            )
             return None
+
         version = self._opset_imports[node.domain]
         op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
         for optimizer in op_optimizers:
@@ -1153,7 +1162,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
             )
             return None
 
-        # Ensure all node inputs are constants
+        # Ensure all node inputs are constants or initializers
         if any(x.const_value is None for x in node.inputs if x is not None):
             return None
 
@@ -1227,10 +1236,13 @@ def convert(av):
         if outputs is None:
             return None
         if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
-            replacement = self.new_constant(node, outputs)
-            if replacement is None:
+            new_initializer_value = self.new_initializer(node, outputs)
+            if new_initializer_value is None:
                 return None
-            return Replacement(replacement.outputs, [replacement])
+            # Add the new initializer to the graph
+            assert node.graph is not None
+            node.graph.register_initializer(new_initializer_value)
+            return Replacement([new_initializer_value], [])
         else:
             logger.warning(
                 "Skipping constant folding for op %s with multiple outputs.", node.op_type
@@ -1244,7 +1256,6 @@ def replace_node(
 
         # Record the names of the values that has contributed to the replacement
         _record_contributing_values(node, replacement)
-
         ir.convenience.replace_nodes_and_values(
             root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
         )

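The heart of the change is that `new_initializer` builds an `ir.Value` carrying the folded tensor as its `const_value`, and `process_node` registers that value on the graph instead of emitting a `Constant` node. Below is a standalone, hedged sketch of that pattern using the `onnx_ir` API directly; the empty `ir.Graph(...)` construction and all names are illustrative assumptions, not code from this commit.

import numpy as np
import onnx_ir as ir

# Create an initializer Value from a folded numpy array, mirroring the
# new_initializer + register_initializer pattern in the diff above.
array = np.array([4.0, 4.0], dtype=np.float32)
tensor = ir.tensor(array)
tensor.name = "four"

initializer = ir.Value(
    name="four",
    type=ir.TensorType(ir.DataType(tensor.dtype)),
    shape=tensor.shape,
    const_value=tensor,
)

# Hypothetical empty graph; the constructor arguments here are assumptions.
graph = ir.Graph(inputs=[], outputs=[], nodes=[], opset_imports={"": 17})
graph.register_initializer(initializer)

assert "four" in graph.initializers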
onnxscript/optimizer/_constant_folding_test.py

Lines changed: 23 additions & 19 deletions
@@ -36,8 +36,8 @@ def test_fold_add(self):
         """
 
         optimized = self._fold(model)
-        self.assertEqual(len(optimized.graph), 2)
-        self.assertEqual(optimized.graph[0].outputs[0].name, "four")
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertIn("four", optimized.graph.initializers)
 
     def test_fold_cast_like(self):
         model = """
@@ -51,8 +51,8 @@ def test_fold_cast_like(self):
         """
 
         optimized = self._fold(model)
-        self.assertEqual(len(optimized.graph), 2)
-        self.assertEqual(optimized.graph[0].outputs[0].name, "four")
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertIn("four", optimized.graph.initializers)
 
     def test_fold_shape(self):
         model = """
@@ -67,8 +67,8 @@ def test_fold_shape(self):
         """
 
         optimized = self._fold(model)
-        self.assertEqual(len(optimized.graph), 2)
-        self.assertEqual(optimized.graph[0].outputs[0].name, "four")
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertIn("four", optimized.graph.initializers)
 
     def test_fold_shape_slice(self):
         model = """
@@ -83,8 +83,8 @@ def test_fold_shape_slice(self):
         """
 
         optimized = self._fold(model)
-        self.assertEqual(len(optimized.graph), 2)
-        self.assertEqual(optimized.graph[0].outputs[0].name, "four")
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertIn("four", optimized.graph.initializers)
 
     def test_fold_if_cond(self):
         model = """
@@ -130,9 +130,11 @@ def test_fold_inside_if_branch(self):
         optimized = self._fold(model)
         self.assertEqual(len(optimized.graph), 1)
         then_graph = optimized.graph[0].attributes["then_branch"].as_graph()
-        self.assertEqual(len(then_graph), 2)
+        self.assertEqual(len(then_graph), 1)
+        self.assertIn("temp", then_graph.initializers)
         else_graph = optimized.graph[0].attributes["else_branch"].as_graph()
-        self.assertEqual(len(else_graph), 2)
+        self.assertEqual(len(else_graph), 1)
+        self.assertIn("temp", else_graph.initializers)
 
     def test_fold_if_propagate(self):
         model = """
@@ -154,9 +156,8 @@ def test_fold_if_propagate(self):
         """
 
         optimized = self._fold(model)
-        self.assertEqual(len(optimized.graph), 2)
-        self.assertEqual(optimized.graph[0].outputs[0].name, "m_square")
-        self.assertEqual(optimized.graph[0].op_type, "Constant")
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertIn("m_square", optimized.graph.initializers)
 
     def test_fold_redundant_cast(self):
         model = """
@@ -209,8 +210,8 @@ def test_shape_inference(self):
         """
 
         optimized = self._fold(model, onnx_shape_inference=True)
-        self.assertEqual(len(optimized.graph), 2)
-        self.assertEqual(optimized.graph[0].outputs[0].name, "C")
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertIn("C", optimized.graph.initializers)
 
     def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(
         self,
@@ -614,7 +615,8 @@ def test_input_size_limit(self):
         # Since there is no increase in model-size, output-size is not a concern.
         optimized = self._fold(model, input_size_limit=256 * 256, output_size_limit=256 * 256)
         ops = [node.op_type for node in optimized.graph]
-        self.assertEqual(ops, ["Constant", "Add"])
+        self.assertEqual(ops, ["Add"])
+        self.assertIn("w_squared", optimized.graph.initializers)
 
     def test_transpose_is_always_folded(self):
         model_text = """
@@ -633,7 +635,8 @@ def test_transpose_is_always_folded(self):
         # Input size limit will not prevent folding of Transpose op
         optimized = self._fold(model, input_size_limit=1)
         ops = [node.op_type for node in optimized.graph]
-        self.assertEqual(ops, ["Constant"])
+        self.assertEqual(ops, [])
+        self.assertIn("z", optimized.graph.initializers)
 
     def test_node_is_folded_if_specified_as_should_fold(self):
         model_text = """
@@ -656,9 +659,10 @@ def test_node_is_folded_if_specified_as_should_fold(self):
             model, should_fold=lambda node: node.op_type == "ConstantOfShape" or None
         )
         ops = [node.op_type for node in optimized.graph]
-        self.assertEqual(ops, ["Constant"])
+        self.assertEqual(ops, [])
+        self.assertIn("z", optimized.graph.initializers)
         np.testing.assert_array_equal(
-            optimized.graph.node(0).attributes["value"].as_tensor().numpy(),
+            optimized.graph.initializers["z"].const_value,
             np.ones((42, 42), dtype=np.int64),
         )

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ classifiers = [
 dependencies = [
     "ml_dtypes",
     "numpy",
-    "onnx_ir>=0.1.10,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range.
+    "onnx_ir>=0.1.12,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range.
     "onnx>=1.16",
     "packaging",
     "typing_extensions>=4.10",
