
Commit d80575d

Keep creating constants when constants are folded inside ir.Function (#2679)
Fixes #2673

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 1a27df1 commit d80575d
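
The substance of the change: ONNX functions cannot hold initializers, so when the constant folder runs inside an ir.Function it now materializes folded values as Constant nodes instead of initializers. A minimal sketch of the resulting behavior, reusing the model text from the new test below; the entry point onnxscript.optimizer.optimize and its return value are assumptions on my part, not part of this commit:

    import onnx.parser
    import onnxscript.optimizer

    MODEL_TEXT = """
    <ir_version: 9, opset_import: ["this" : 1, "" : 19]>
    model (float x) => (float return_val) {
        return_val = this.function (x)
    }
    <domain: "this", opset_import: ["" : 19]>
    function (x) => (return_val) {
        tmp = Constant <value_int=1> ()
        tmp_0 = Cast <to=1> (tmp)
        return_val = Sub (tmp_0, x)
    }
    """

    model = onnx.parser.parse_model(MODEL_TEXT)
    # Assumption: optimize() accepts and returns a ModelProto here.
    optimized = onnxscript.optimizer.optimize(model)

    # With this fix, folding Cast(Constant) inside the function body yields a
    # Constant node; the initializer path is reserved for graphs, which can
    # actually store initializers.
    for function in optimized.functions:
        print([node.op_type for node in function.node])  # e.g. ["Constant", "Sub"]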

2 files changed: +90 -30 lines


onnxscript/optimizer/_constant_folding.py

Lines changed: 70 additions & 30 deletions
@@ -1050,46 +1050,79 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                 e,
             )
 
-    def new_initializer(self, node: ir.Node, array) -> ir.Value | None:
-        original_value = node.outputs[0]
-        if not isinstance(array, np.ndarray):
-            # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
-            # So, a constant-value of type sequence is not folded, but it can be used
-            # to optimize subsequent operations when possible.
+    def _prepare_folded_tensor(
+        self, node: ir.Node, output_name: str, output_array: np.ndarray | Any
+    ) -> ir.Tensor | None:
+        """
+        Shared helper for constant/init creation:
+        - Validates the folded Python value is a numpy ndarray.
+        - Wraps it in an ir.Tensor and names it.
+        - Applies output_size_limit logic with input-usage compensation.
+        Returns the ir.Tensor or None if it should be skipped.
+        """
+        if not isinstance(output_array, np.ndarray):
             logger.info(
                 "Skip storing constant folded value %s due to unsupported type %s.",
-                original_value.name,
-                type(array),
+                output_name,
+                type(output_array),
             )
             return None
 
-        tensor = ir.tensor(array)
-        tensor.name = original_value.name
-        initializer = ir.Value(
-            name=original_value.name,
-            type=ir.TensorType(ir.DataType(tensor.dtype)),
-            shape=tensor.shape,  # type: ignore[arg-type]
-            const_value=tensor,
-        )
+        tensor = ir.tensor(output_array)
+        tensor.name = output_name
 
-        if array.size > self.output_size_limit:
-            # Handle examples like Transpose(weight) to be folded even if the size is large,
-            # as long as weight has no other uses. This won't increase model size.
+        # Size gating (shared logic)
+        if output_array.size > self.output_size_limit:
             removed_input_size = 0
-            for input in node.inputs:
-                if (input is not None) and (len(input.uses()) == 1):
-                    array = _get_numpy_value(input)
-                    if array is not None:
-                        removed_input_size += array.size
-            increased_size = array.size - removed_input_size
+            for input_val in node.inputs:
+                if (input_val is not None) and (len(input_val.uses()) == 1):
+                    input_array = _get_numpy_value(input_val)
+                    if input_array is not None:
+                        removed_input_size += input_array.size
+            increased_size = output_array.size - removed_input_size
             if increased_size > 0:
                 logger.info(
-                    "Skip storing constant folded nvalue %s due to large size %s.",
-                    original_value.name,
-                    array.size,
+                    "Skip storing constant folded array %s due to large size %s.",
+                    output_name,
+                    output_array.size,
                 )
                 return None
 
+        return tensor
+
+    def new_constant(self, node: ir.Node, array: np.ndarray | Any) -> ir.Node | None:
+        """Create a new Constant node with the given array as its value."""
+        original_value = node.outputs[0]
+
+        tensor = self._prepare_folded_tensor(node, original_value.name, array)
+        if tensor is None:
+            return None
+
+        logger.debug(
+            "New constant for value %s dtype: %s shape: %s",
+            original_value.name,
+            array.dtype,
+            array.shape,
+        )
+
+        node = ir.Node("", "Constant", inputs=[], attributes=(ir.AttrTensor("value", tensor),))
+        return node
+
+    def new_initializer(self, node: ir.Node, array: np.ndarray | Any) -> ir.Value | None:
+        """Create a new initializer value with the given array as its value."""
+        original_value = node.outputs[0]
+
+        tensor = self._prepare_folded_tensor(node, original_value.name, array)
+        if tensor is None:
+            return None
+
+        initializer = ir.Value(
+            name=original_value.name,
+            type=ir.TensorType(ir.DataType(tensor.dtype)),
+            shape=tensor.shape,  # type: ignore[arg-type]
+            const_value=tensor,
+        )
+
         logger.debug(
             "New Initializer for value %s dtype: %s shape: %s",
             original_value.name,
@@ -1099,7 +1132,7 @@ def new_initializer(self, node: ir.Node, array) -> ir.Value | None:
 
         return initializer
 
-    def process_node(self, node: ir.Node) -> Replacement | None:
+    def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None:
         """Process a node and return a Replacement if the node can be replaced."""
         for i, value in enumerate(node.inputs):
             sym_value = self._state.get_sym_value(value)
@@ -1252,6 +1285,12 @@ def convert(av):
         if outputs is None:
             return None
         if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
+            # We don't support initializers in functions, so we need to create Constant nodes
+            if is_function:
+                replacement = self.new_constant(node, outputs)
+                if replacement is None:
+                    return None
+                return Replacement(replacement.outputs, [replacement])
             new_initializer_value = self.new_initializer(node, outputs)
             if new_initializer_value is None:
                 return None
@@ -1301,7 +1340,8 @@ def visit_attribute(self, attr: ir.Attr) -> None:
         self.visit_graph(graph)
 
     def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
-        replacement = self.process_node(node)
+        is_function = isinstance(root, ir.Function)
+        replacement = self.process_node(node, is_function=is_function)
         if replacement is None:
             # No change. Process attributes.
             for attr in node.attributes.values():
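
To make the size gate in _prepare_folded_tensor concrete, here is a worked example; the limit, the weight, and all sizes are made-up numbers for illustration, not values from this commit:

    # Folding y = Transpose(weight), with output_size_limit = 1024 and a
    # weight of 1_000_000 elements consumed only by this Transpose:
    output_size = 1_000_000           # folded result exceeds the limit...
    removed_input_size = 1_000_000    # ...but weight dies with the fold
    increased_size = output_size - removed_input_size
    assert increased_size == 0        # not > 0, so the folded tensor is kept

    # If weight had a second consumer, it would not count as removed
    # (len(input_val.uses()) != 1), increased_size would be 1_000_000 > 0,
    # and the folded tensor would be skipped.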

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 20 additions & 0 deletions
@@ -721,6 +721,26 @@ def test_attribute_reference(self):
         optimized = self._fold(model)
         self.assertEqual(len(optimized.graph), 2)
 
+    def test_constant_folding_creates_constant_nodes_in_function(self):
+        model = """
+            <ir_version: 9, opset_import: ["this" : 1, "" : 19]>
+            model (float x) => (float return_val) {
+                return_val = this.function (x)
+            }
+            <domain: "this", opset_import: ["" : 19]>
+            function (x) => (return_val) {
+                tmp = Constant <value_int=1> ()
+                tmp_0 = Cast <to=1> (tmp)
+                return_val = Sub (tmp_0, x)
+            }
+        """
+        optimized = self._fold(model)
+        self.assertEqual(len(optimized.functions), 1)
+        for func in optimized.functions.values():
+            # Ensure that constant folding has created constant nodes in the function
+            constant_nodes = [n for n in func.graph if n.op_type == "Constant"]
+            self.assertEqual(len(constant_nodes), 1)
+
 
 if __name__ == "__main__":
     unittest.main()
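
To exercise just the new test locally, something like pytest onnxscript/optimizer/_constant_folding_test.py -k test_constant_folding_creates_constant_nodes_in_function should work, assuming pytest is installed; plain python -m unittest with the module path works as well.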
