99import logging
1010import math
1111import typing
12- from typing import Any , Callable , Collection , Iterable , Sequence , Union
12+ from typing import Any , Callable , Iterable , Sequence , Union
1313
1414import numpy as np
1515import onnx
3434 }
3535)
3636
37+ # A set of ops to always fold regardless of the input size limits, as long as
38+ # the op is the only consumer of its large input tensors
39+ _DEFAULT_ALWAYS_FOLD_OPS = frozenset (
40+ {
41+ ("" , "Transpose" ),
42+ }
43+ )
3744
3845logger = logging .getLogger (__name__ )
3946
@@ -332,12 +339,6 @@ def _get_output(node: ir.Node, index: int) -> ir.Value | None:
332339 return None
333340
334341
335- def _update_type (value : ir .Value , type : ir .TypeProtocol | None ) -> None :
336- if type is not None :
337- # TODO: merge types
338- value .type = type
339-
340-
341342def _get_input_element_type (node : ir .Node , index : int ) -> int :
342343 input = _get_input (node , index )
343344 if input is not None and input .type is not None :
@@ -899,9 +900,10 @@ class FoldConstantsPass(ir.passes.InPlacePass):
899900 shape_inference: Whether to perform shape inference.
900901 input_size_limit: Maximum size of input tensors to fold.
901902 output_size_limit: Maximum size of output tensors to fold.
902- always_fold_ops: Collection of op types that should always be folded.
903- For ops from the default opset, only op_type is neede (e.g. "Transpose"),
904- otherwise specify the domain with ``{domain}::{op_type}``.
903+ should_fold: An optional function that takes a node and returns True if
904+ the node should be folded, False if it should not be folded,
905+ or None to use the default folding rules.
906+ Defaults to a function that always returns None.
905907 """
906908
907909 def __init__ (
@@ -910,18 +912,12 @@ def __init__(
910912 shape_inference : bool ,
911913 input_size_limit : int ,
912914 output_size_limit : int ,
913- always_fold_ops : Collection [ str ] = frozenset ([ "Transpose" ]) ,
915+ should_fold : Callable [[ ir . Node ], bool | None ] = lambda node : None ,
914916 ) -> None :
915917 self .shape_inference = shape_inference
916918 self .input_size_limit = input_size_limit
917919 self .output_size_limit = output_size_limit
918- ops = []
919- for name in always_fold_ops :
920- domain , op_type = name .split ("::" , 1 ) if "::" in name else ("" , name )
921- if domain == "ai.onnx" :
922- domain = ""
923- ops .append ((domain , op_type ))
924- self .always_fold_ops : frozenset [tuple [str , str ]] = frozenset (ops )
920+ self .should_fold = should_fold
925921
926922 self ._opset_imports : dict [str , int ] = {}
927923 self ._counts : dict [str , int ] = {}
@@ -961,7 +957,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
961957 input_data = {k : v for k , v in input_data .items () if v is not None }
962958 if any (t is None for t in input_types .values ()):
963959 logger .debug (
964- "Skipping shape inference for node %s due to missing input type." ,
960+ "Skipping shape inference for node %r due to missing input type." ,
965961 node .name ,
966962 )
967963 else :
@@ -987,7 +983,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
987983 output .type = ir .serde .deserialize_type_proto_for_type (inferred_type )
988984 except Exception as e :
989985 logger .debug (
990- "Skipping shape inference for node %s due to exception: %s" ,
986+ "Skipping shape inference for node %r due to exception: %s" ,
991987 node .name ,
992988 e ,
993989 )
@@ -1072,62 +1068,102 @@ def process_node(self, node: ir.Node) -> Replacement | None:
10721068 output = [output ]
10731069 return Replacement (output , context .nodes )
10741070
1075- if _is_control_flow_op (node ) or _is_non_deterministic_op (node ):
1071+ if _is_control_flow_op (node ):
1072+ logger .info (
1073+ "Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet" ,
1074+ node .name ,
1075+ node .domain ,
1076+ node .op_type ,
1077+ )
1078+
1079+ return None
1080+
1081+ if _is_non_deterministic_op (node ):
1082+ logger .info (
1083+ "Skipping constant folding for non-deterministic op %r (%s::%s)" ,
1084+ node .name ,
1085+ node .domain ,
1086+ node .op_type ,
1087+ )
10761088 return None
10771089
10781090 if _is_onnx_op (node , "Constant" ):
10791091 _process_constant_node (node )
10801092 return None
10811093
10821094 if any (x .is_graph_input () for x in node .inputs if x is not None ):
1083- # Do not fold any graph inputs to preserve graph signature
1095+ logger .info (
1096+ "Skipping constant folding for node %r because it is graph input to preserve graph signature" ,
1097+ node .name ,
1098+ )
10841099 return None
10851100
10861101 # Ensure all node inputs are constants
10871102 if any (x .const_value is None for x in node .inputs if x is not None ):
1088- if logger .isEnabledFor (logging .DEBUG ):
1089- logger .debug (
1090- "Skipping constant folding for node %s because it has non-constant inputs" ,
1091- node ,
1092- [x .name for x in node .inputs if x is not None ],
1093- )
10941103 return None
10951104
1096- input_tensors = [x .const_value if x is not None else None for x in node .inputs ]
1097- if any (
1098- tensor .size > self .input_size_limit
1099- for tensor in input_tensors
1100- if tensor is not None
1101- ):
1102- if (node .domain , node .op_type ) in self .always_fold_ops and all (
1103- len (input .consumers ()) == 1 for input in node .inputs if input is not None
1104- ):
1105- # If the op is in always_fold_ops and all inputs are used only by this node,
1106- # we can still fold it even if the input size exceeds the limit.
1107- logger .debug (
1108- "Folding large constant for node %s because it is in the always_fold_ops list" ,
1109- node ,
1105+ should_fold = self .should_fold (node )
1106+
1107+ if should_fold is False :
1108+ logger .info (
1109+ "Skipping constant folding for node %r because should_fold returned False" ,
1110+ node .name ,
1111+ )
1112+ return None
1113+
1114+ elif should_fold is None :
1115+ # Use default rules to decide whether to fold the node:
1116+ # - ConstantOfShape is preserved to avoid increasing model size unnecessarily
1117+ # - If any tensor input size exceeds the input_size_limit, skip folding the node
1118+ if _is_onnx_op (node , "ConstantOfShape" ):
1119+ logger .info (
1120+ "Skipping constant folding for node %r because ConstantOfShape is preserved by default" ,
1121+ node .name ,
11101122 )
1111- else :
1112- # Skip folding large tensors
1113- if logger .isEnabledFor (logging .DEBUG ):
1114- input_sizes = [
1115- tensor .size for tensor in input_tensors if tensor is not None
1116- ]
1117- logger .debug (
1118- "Skipping constant folding for node %s due to large input size: %s" ,
1119- node ,
1120- input_sizes ,
1121- )
11221123 return None
11231124
1125+ input_tensors = [x .const_value if x is not None else None for x in node .inputs ]
1126+ large_inputs = [
1127+ tensor is not None and tensor .size > self .input_size_limit
1128+ for tensor in input_tensors
1129+ ]
1130+ if any (large_inputs ):
1131+ # Decide whether to fold large constants
1132+ assert len (node .inputs ) == len (large_inputs )
1133+ if (node .domain , node .op_type ) in _DEFAULT_ALWAYS_FOLD_OPS and all (
1134+ len (input .consumers ()) == 1 or (not is_large )
1135+ for input , is_large in zip (node .inputs , large_inputs )
1136+ if input is not None
1137+ ):
1138+ # If the op is in _DEFAULT_ALWAYS_FOLD_OPS and all large inputs are used only by this node,
1139+ # we can still fold it even if the input size exceeds the limit
1140+ pass
1141+ else :
1142+ # Skip folding large tensors
1143+ if logger .isEnabledFor (logging .INFO ):
1144+ input_sizes = [
1145+ tensor .size for tensor in input_tensors if tensor is not None
1146+ ]
1147+ logger .info (
1148+ "Skipping constant folding for node %r due to large input sizes: %s" ,
1149+ node ,
1150+ input_sizes ,
1151+ )
1152+ return None
1153+ else :
1154+ logger .info (
1155+ "Constant folding node %r because should_fold returned True" ,
1156+ node .name ,
1157+ )
1158+
11241159 input_values = [_get_numpy_value (x ) for x in node .inputs ]
11251160
11261161 def convert (av ):
11271162 if av .type == ir .AttributeType .TENSOR :
11281163 return ir .serde .serialize_tensor (av .value )
11291164 return av .value
11301165
1166+ # TODO(justinchuby): We should find a way to avoid serializing tensors every time we want to evaluate a node
11311167 attr_values = {name : convert (attr ) for name , attr in node .attributes .items ()}
11321168 outputs = _reference_evaluator .evaluate (
11331169 node .domain , node .op_type , version , * input_values , ** attr_values
@@ -1137,7 +1173,7 @@ def convert(av):
11371173 return None
11381174 if len (node .outputs ) == 1 and not isinstance (outputs , (tuple , list )):
11391175 replacement = self .new_constant (node , outputs )
1140- if _is_onnx_op ( node , "ConstantOfShape" ) or replacement is None :
1176+ if replacement is None :
11411177 return None
11421178 return Replacement (replacement .outputs , [replacement ])
11431179 else :
@@ -1245,7 +1281,7 @@ def fold_constants(
12451281 onnx_shape_inference : bool = False ,
12461282 input_size_limit : int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT ,
12471283 output_size_limit : int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT ,
1248- always_fold_ops : Collection [ str ] = frozenset ([ "Transpose" ]) ,
1284+ should_fold : Callable [[ ir . Node ], bool | None ] = lambda node : None ,
12491285) -> FoldConstantsResult :
12501286 """
12511287 Applies constant folding optimization to the model.
@@ -1260,10 +1296,9 @@ def fold_constants(
12601296 output_size_limit: The maximum size of output tensors
12611297 that can be stored after constant folding. Defaults to
12621298 `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
1263- always_fold_ops: A collection of op types that should always be folded,
1264- regardless of their input or output sizes. For ops from the default opset,
1265- only op_type is neede (e.g. "Transpose"), otherwise specify the domain
1266- with ``{domain}::{op_type}``.
1299+ should_fold: An optional function that takes a node and returns True if
1300+ the node should be considered for folding, False if it should not be folded,
1301+ or None to use the default rules. Defaults to a function that always returns None.
12671302
12681303 Returns:
12691304 An instance of `FoldConstantsResult`.
@@ -1273,6 +1308,6 @@ def fold_constants(
12731308 shape_inference = onnx_shape_inference ,
12741309 input_size_limit = input_size_limit ,
12751310 output_size_limit = output_size_limit ,
1276- always_fold_ops = always_fold_ops ,
1311+ should_fold = should_fold ,
12771312 )
12781313 return folder_pass (model ) # type: ignore[return-value]
0 commit comments