Skip to content

Commit 28a8f56

Browse files
authored
Fix constant in constant folding (#2622)
This PR moves the processing of constant ops upward to return before node-level shape type inference (including serialization) and optimizer optimization. Essentially, avoiding serializing constant ops (potentially large weights in LLMs) reduces the export time in optimize_ir. Before this PR: <img width="2119" height="282" alt="Screenshot 2025-10-09 141403" src="https://github.com/user-attachments/assets/4d50d3ee-ce84-4f8b-bc20-f56497d7dad1" /> After this PR: <img width="2390" height="319" alt="Screenshot 2025-10-09 141238" src="https://github.com/user-attachments/assets/f455942e-dde8-4a07-a94e-a5f817358579" />
1 parent 59c3d32 commit 28a8f56

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
7676

7777
def _process_constant_node(node: ir.Node) -> None:
7878
"""Sets const_value of output value of a Constant op node."""
79-
if node.op_type != "Constant" or node.domain != "":
79+
if not _is_onnx_op(node, "Constant"):
8080
return
8181
if len(node.attributes) != 1:
8282
return
@@ -1099,8 +1099,12 @@ def process_node(self, node: ir.Node) -> Replacement | None:
10991099
self._modified = True
11001100
# TODO(rama): consider merging type/other info from both values
11011101

1102+
# Propagate const_value, and manually find out shape and type
1103+
# to avoid potentially expensive shape inference on large tensors.
1104+
if _is_onnx_op(node, "Constant"):
1105+
_process_constant_node(node)
11021106
# Do incremental shape inference
1103-
if self.shape_inference and not _is_control_flow_op(node):
1107+
elif self.shape_inference and not _is_control_flow_op(node):
11041108
self._do_inference(node)
11051109

11061110
if node.domain not in self._opset_imports:
@@ -1118,6 +1122,10 @@ def process_node(self, node: ir.Node) -> Replacement | None:
11181122
output = [output]
11191123
return Replacement(output, context.nodes)
11201124

1125+
if _is_onnx_op(node, "Constant"):
1126+
logger.debug("Skipping constant folding for Constant node %r", node.name)
1127+
return None
1128+
11211129
if _is_control_flow_op(node):
11221130
logger.info(
11231131
"Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet",
@@ -1137,10 +1145,6 @@ def process_node(self, node: ir.Node) -> Replacement | None:
11371145
)
11381146
return None
11391147

1140-
if _is_onnx_op(node, "Constant"):
1141-
_process_constant_node(node)
1142-
return None
1143-
11441148
if any(x.is_graph_input() for x in node.inputs if x is not None):
11451149
logger.info(
11461150
"Skipping constant folding for node %r because it is graph input to preserve graph signature",

0 commit comments

Comments
 (0)