@@ -76,7 +76,7 @@ def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
7676
7777def _process_constant_node (node : ir .Node ) -> None :
7878 """Sets const_value of output value of a Constant op node."""
79- if node . op_type != "Constant" or node . domain != "" :
79+ if not _is_onnx_op ( node , "Constant" ) :
8080 return
8181 if len (node .attributes ) != 1 :
8282 return
@@ -1099,8 +1099,12 @@ def process_node(self, node: ir.Node) -> Replacement | None:
10991099 self ._modified = True
11001100 # TODO(rama): consider merging type/other info from both values
11011101
1102+ # Propagate const_value, and manually find out shape and type
1103+ # to avoid potentially expensive shape inference on large tensors.
1104+ if _is_onnx_op (node , "Constant" ):
1105+ _process_constant_node (node )
11021106 # Do incremental shape inference
1103- if self .shape_inference and not _is_control_flow_op (node ):
1107+ elif self .shape_inference and not _is_control_flow_op (node ):
11041108 self ._do_inference (node )
11051109
11061110 if node .domain not in self ._opset_imports :
@@ -1118,6 +1122,10 @@ def process_node(self, node: ir.Node) -> Replacement | None:
11181122 output = [output ]
11191123 return Replacement (output , context .nodes )
11201124
1125+ if _is_onnx_op (node , "Constant" ):
1126+ logger .debug ("Skipping constant folding for Constant node %r" , node .name )
1127+ return None
1128+
11211129 if _is_control_flow_op (node ):
11221130 logger .info (
11231131 "Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet" ,
@@ -1137,10 +1145,6 @@ def process_node(self, node: ir.Node) -> Replacement | None:
11371145 )
11381146 return None
11391147
1140- if _is_onnx_op (node , "Constant" ):
1141- _process_constant_node (node )
1142- return None
1143-
11441148 if any (x .is_graph_input () for x in node .inputs if x is not None ):
11451149 logger .info (
11461150 "Skipping constant folding for node %r because it is graph input to preserve graph signature" ,
0 commit comments