
Commit 456a6bc

justinchuby and Copilot authored

Update constant folding behavior for large tensors (#2488)
Suggested by #2466, I updated the constant folder logic to allow per-node folding customization.

**Constant folding customization:**

* Replaced the `always_fold_ops` parameter with a `should_fold` callable that decides on a per-node basis whether folding should occur. This allows users to specify more complex folding policies and makes the API more explicit. (`FoldConstantsPass`, `fold_constants`)

**Logging and diagnostics improvements:**

* Upgraded logging throughout the folding process to provide more informative messages, including the reasons for skipping nodes (e.g., control flow, non-deterministic ops, large inputs, or graph inputs) and explicit logging when `should_fold` returns a decision.

**Code cleanup and minor fixes:**

* Removed the unused `_update_type` function.

Fix #2466

cc @iksnagreb

---------

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
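Editor's note: a minimal usage sketch of the new `should_fold` hook described above. The model file name, the `com.example` domain, and the policy function name are illustrative assumptions, not part of this commit.

```python
from onnxscript import ir, optimizer

def my_policy(node: ir.Node) -> bool | None:
    """Hypothetical per-node folding policy."""
    if node.domain == "com.example":  # assumed custom domain: never fold these
        return False
    if node.op_type == "ConstantOfShape":  # force folding despite the default skip
        return True
    return None  # defer to the default folding rules

model = ir.load("model.onnx")  # assumed input file
result = optimizer.fold_constants(model, should_fold=my_policy)
```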
1 parent a925acc · commit 456a6bc

2 files changed: +122 −60 lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 95 additions & 60 deletions
@@ -9,7 +9,7 @@
 import logging
 import math
 import typing
-from typing import Any, Callable, Collection, Iterable, Sequence, Union
+from typing import Any, Callable, Iterable, Sequence, Union
 
 import numpy as np
 import onnx
@@ -34,6 +34,13 @@
     }
 )
 
+# A list of ops to always fold regardless of their input size limits, as long as
+# they are the single consumer of the large input tensors
+_DEFAULT_ALWAYS_FOLD_OPS = frozenset(
+    {
+        ("", "Transpose"),
+    }
+)
 
 logger = logging.getLogger(__name__)
 
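Editor's note: ops in this set are keyed by `(domain, op_type)` tuples, where the default ONNX domain is the empty string. A tiny self-contained sketch of the membership check (the helper name is illustrative):

```python
# Sketch: how (domain, op_type) keys identify always-fold candidates.
ALWAYS_FOLD = frozenset({("", "Transpose")})  # "" is the default ONNX domain

def is_always_fold_candidate(domain: str, op_type: str) -> bool:
    return (domain, op_type) in ALWAYS_FOLD

assert is_always_fold_candidate("", "Transpose")
assert not is_always_fold_candidate("com.microsoft", "Transpose")
```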
@@ -332,12 +339,6 @@ def _get_output(node: ir.Node, index: int) -> ir.Value | None:
     return None
 
 
-def _update_type(value: ir.Value, type: ir.TypeProtocol | None) -> None:
-    if type is not None:
-        # TODO: merge types
-        value.type = type
-
-
 def _get_input_element_type(node: ir.Node, index: int) -> int:
     input = _get_input(node, index)
     if input is not None and input.type is not None:
@@ -899,9 +900,10 @@ class FoldConstantsPass(ir.passes.InPlacePass):
         shape_inference: Whether to perform shape inference.
         input_size_limit: Maximum size of input tensors to fold.
         output_size_limit: Maximum size of output tensors to fold.
-        always_fold_ops: Collection of op types that should always be folded.
-            For ops from the default opset, only op_type is neede (e.g. "Transpose"),
-            otherwise specify the domain with ``{domain}::{op_type}``.
+        should_fold: An optional function that takes a node and returns True if
+            the node should be considered for folding.
+            The function should return True/False value to indicate if this particular
+            node should be folded, or None to use the default folding rules.
     """
 
     def __init__(
@@ -910,18 +912,12 @@ def __init__(
         shape_inference: bool,
         input_size_limit: int,
         output_size_limit: int,
-        always_fold_ops: Collection[str] = frozenset(["Transpose"]),
+        should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
     ) -> None:
         self.shape_inference = shape_inference
         self.input_size_limit = input_size_limit
         self.output_size_limit = output_size_limit
-        ops = []
-        for name in always_fold_ops:
-            domain, op_type = name.split("::", 1) if "::" in name else ("", name)
-            if domain == "ai.onnx":
-                domain = ""
-            ops.append((domain, op_type))
-        self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops)
+        self.should_fold = should_fold
 
         self._opset_imports: dict[str, int] = {}
         self._counts: dict[str, int] = {}
@@ -961,7 +957,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
         input_data = {k: v for k, v in input_data.items() if v is not None}
         if any(t is None for t in input_types.values()):
             logger.debug(
-                "Skipping shape inference for node %s due to missing input type.",
+                "Skipping shape inference for node %r due to missing input type.",
                 node.name,
             )
         else:
@@ -987,7 +983,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                     output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
         except Exception as e:
             logger.debug(
-                "Skipping shape inference for node %s due to exception: %s",
+                "Skipping shape inference for node %r due to exception: %s",
                 node.name,
                 e,
             )
@@ -1072,62 +1068,102 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                 output = [output]
             return Replacement(output, context.nodes)
 
-        if _is_control_flow_op(node) or _is_non_deterministic_op(node):
+        if _is_control_flow_op(node):
+            logger.info(
+                "Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet",
+                node.name,
+                node.domain,
+                node.op_type,
+            )
+
+            return None
+
+        if _is_non_deterministic_op(node):
+            logger.info(
+                "Skipping constant folding for non-deterministic op %r (%s::%s)",
+                node.name,
+                node.domain,
+                node.op_type,
+            )
             return None
 
         if _is_onnx_op(node, "Constant"):
             _process_constant_node(node)
             return None
 
         if any(x.is_graph_input() for x in node.inputs if x is not None):
-            # Do not fold any graph inputs to preserve graph signature
+            logger.info(
+                "Skipping constant folding for node %r because it is graph input to preserve graph signature",
+                node.name,
+            )
             return None
 
         # Ensure all node inputs are constants
         if any(x.const_value is None for x in node.inputs if x is not None):
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    "Skipping constant folding for node %s because it has non-constant inputs",
-                    node,
-                    [x.name for x in node.inputs if x is not None],
-                )
             return None
 
-        input_tensors = [x.const_value if x is not None else None for x in node.inputs]
-        if any(
-            tensor.size > self.input_size_limit
-            for tensor in input_tensors
-            if tensor is not None
-        ):
-            if (node.domain, node.op_type) in self.always_fold_ops and all(
-                len(input.consumers()) == 1 for input in node.inputs if input is not None
-            ):
-                # If the op is in always_fold_ops and all inputs are used only by this node,
-                # we can still fold it even if the input size exceeds the limit.
-                logger.debug(
-                    "Folding large constant for node %s because it is in the always_fold_ops list",
-                    node,
+        should_fold = self.should_fold(node)
+
+        if should_fold is False:
+            logger.info(
+                "Skipping constant folding for node %r because should_fold returned False",
+                node.name,
+            )
+            return None
+
+        elif should_fold is None:
+            # Use default rules to decide whether to fold the node:
+            # - ConstantOfShape is preserved to avoid increasing model size unnecessarily
+            # - If the any tensor input size exceeds the input_size_limit, skip folding the node
+            if _is_onnx_op(node, "ConstantOfShape"):
+                logger.info(
+                    "Skipping constant folding for node %r because ConstantOfShape is preserved by default",
+                    node.name,
                 )
-            else:
-                # Skip folding large tensors
-                if logger.isEnabledFor(logging.DEBUG):
-                    input_sizes = [
-                        tensor.size for tensor in input_tensors if tensor is not None
-                    ]
-                    logger.debug(
-                        "Skipping constant folding for node %s due to large input size: %s",
-                        node,
-                        input_sizes,
-                    )
                 return None
 
+            input_tensors = [x.const_value if x is not None else None for x in node.inputs]
+            large_inputs = [
+                tensor is not None and tensor.size > self.input_size_limit
+                for tensor in input_tensors
+            ]
+            if any(large_inputs):
+                # Decide whether to fold large constants
+                assert len(node.inputs) == len(large_inputs)
+                if (node.domain, node.op_type) in _DEFAULT_ALWAYS_FOLD_OPS and all(
+                    len(input.consumers()) == 1 or (not is_large)
+                    for input, is_large in zip(node.inputs, large_inputs)
+                    if input is not None
+                ):
+                    # If the op is in _DEFAULT_ALWAYS_FOLD_OPS and all large inputs are used only by this node,
+                    # we can still fold it even if the input size exceeds the limit
+                    pass
+                else:
+                    # Skip folding large tensors
+                    if logger.isEnabledFor(logging.INFO):
+                        input_sizes = [
+                            tensor.size for tensor in input_tensors if tensor is not None
+                        ]
+                        logger.info(
+                            "Skipping constant folding for node %r due to large input sizes: %s",
+                            node,
+                            input_sizes,
+                        )
+                    return None
+        else:
+            logger.info(
+                "Constant folding node %r because should_fold returned True",
+                node.name,
+            )
+
         input_values = [_get_numpy_value(x) for x in node.inputs]
 
         def convert(av):
             if av.type == ir.AttributeType.TENSOR:
                 return ir.serde.serialize_tensor(av.value)
             return av.value
 
+        # TODO(justinchuby): We should find a way to avoid serializing tensors every time we want to evaluate a node
         attr_values = {name: convert(attr) for name, attr in node.attributes.items()}
         outputs = _reference_evaluator.evaluate(
             node.domain, node.op_type, version, *input_values, **attr_values
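Editor's note: the large-input rule in this hunk can be restated standalone: a node may fold past `input_size_limit` only when its op is in `_DEFAULT_ALWAYS_FOLD_OPS` and every oversized input is consumed by that node alone. A sketch with made-up sizes and limit (not the library's actual defaults):

```python
def may_fold_large(op_key, input_sizes, consumer_counts, input_size_limit):
    """Illustrative restatement of the folding decision for large inputs."""
    large = [size > input_size_limit for size in input_sizes]
    if not any(large):
        return True  # nothing oversized: the size check passes
    return op_key in {("", "Transpose")} and all(
        count == 1 or not is_large
        for count, is_large in zip(consumer_counts, large)
    )

# A 2M-element Transpose input consumed only by this node still folds...
assert may_fold_large(("", "Transpose"), [2_000_000], [1], 1_000_000)
# ...but not when the oversized tensor has another consumer.
assert not may_fold_large(("", "Transpose"), [2_000_000], [2], 1_000_000)
```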
@@ -1137,7 +1173,7 @@ def convert(av):
             return None
         if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
             replacement = self.new_constant(node, outputs)
-            if _is_onnx_op(node, "ConstantOfShape") or replacement is None:
+            if replacement is None:
                 return None
             return Replacement(replacement.outputs, [replacement])
         else:
@@ -1245,7 +1281,7 @@ def fold_constants(
     onnx_shape_inference: bool = False,
     input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
     output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
-    always_fold_ops: Collection[str] = frozenset(["Transpose"]),
+    should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
 ) -> FoldConstantsResult:
     """
     Applies constant folding optimization to the model.
@@ -1260,10 +1296,9 @@ def fold_constants(
         output_size_limit: The maximum size of output tensors
             that can be stored after constant folding. Defaults to
             `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
-        always_fold_ops: A collection of op types that should always be folded,
-            regardless of their input or output sizes. For ops from the default opset,
-            only op_type is neede (e.g. "Transpose"), otherwise specify the domain
-            with ``{domain}::{op_type}``.
+        should_fold: An optional function that takes a node and returns True if
+            the node should be considered for folding, False if it should not be folded,
+            or None to use the default rules. Defaults to a function that always returns None.
 
     Returns:
         An instance of `FoldConstantsResult`.
@@ -1273,6 +1308,6 @@ def fold_constants(
         shape_inference=onnx_shape_inference,
         input_size_limit=input_size_limit,
         output_size_limit=output_size_limit,
-        always_fold_ops=always_fold_ops,
+        should_fold=should_fold,
     )
     return folder_pass(model)  # type: ignore[return-value]
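Editor's note: for callers migrating off the removed parameter, the old behavior can be approximated with a callback. This is only an approximation (the domain/op name below is hypothetical, and the old parameter additionally required the single-consumer condition for oversized inputs, whereas returning True bypasses size checks entirely):

```python
# Sketch: approximate migration from the removed parameter
#     always_fold_ops=["com.example::MyOp"]   (hypothetical custom op)
# to the new callback. Returning True forces folding; None keeps default rules.
def should_fold(node) -> bool | None:
    if (node.domain, node.op_type) == ("com.example", "MyOp"):
        return True
    return None
```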

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 27 additions & 0 deletions
@@ -581,6 +581,33 @@ def test_transpose_is_always_folded(self):
         ops = [node.op_type for node in optimized.graph]
         self.assertEqual(ops, ["Constant"])
 
+    def test_node_is_folded_if_specified_as_should_fold(self):
+        model_text = """
+            <ir_version: 10, opset_import: [ "" : 20]>
+            agraph (float[M, 256] x) => (float[42, 42] z)
+            <int64[2] w = {42, 42}>
+            {
+                z = ConstantOfShape <value: tensor = int64[1] {1}> (w)
+            }
+        """
+        model = ir.from_onnx_text(model_text)
+
+        # ConstantOfShape is not folded by default
+        optimized = self._fold(model)
+        ops = [node.op_type for node in optimized.graph]
+        self.assertEqual(ops, ["ConstantOfShape"])
+
+        # But ConstantOfShape is folded when specified in should_fold
+        optimized = self._fold(
+            model, should_fold=lambda node: node.op_type == "ConstantOfShape" or None
+        )
+        ops = [node.op_type for node in optimized.graph]
+        self.assertEqual(ops, ["Constant"])
+        np.testing.assert_array_equal(
+            optimized.graph.node(0).attributes["value"].as_tensor().numpy(),
+            np.ones((42, 42), dtype=np.int64),
+        )
+
     def test_multi_graph_identity_output_preserves_output_name(self):
         model = """
             <ir_version: 10, opset_import: ["" : 20]>
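Editor's note: the lambda in the new test leans on Python's `or` semantics: when the comparison is False, `False or None` evaluates to None, so non-matching nodes fall back to the default rules rather than being explicitly blocked. A quick illustration:

```python
policy = lambda op_type: (op_type == "ConstantOfShape") or None
assert policy("ConstantOfShape") is True  # fold this node
assert policy("Transpose") is None        # defer to default rules (not False)
```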
