@@ -1050,46 +1050,79 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                 e,
             )
 
-    def new_initializer(self, node: ir.Node, array) -> ir.Value | None:
-        original_value = node.outputs[0]
-        if not isinstance(array, np.ndarray):
-            # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
-            # So, a constant-value of type sequence is not folded, but it can be used
-            # to optimize subsequent operations when possible.
+    def _prepare_folded_tensor(
+        self, node: ir.Node, output_name: str, output_array: np.ndarray | Any
+    ) -> ir.Tensor | None:
+        """
+        Shared helper for constant/initializer creation:
+        - Validates the folded Python value is a numpy ndarray.
+        - Wraps it in an ir.Tensor and names it.
+        - Applies output_size_limit logic with input-usage compensation.
+        Returns the ir.Tensor or None if it should be skipped.
+        """
+        if not isinstance(output_array, np.ndarray):
             logger.info(
                 "Skip storing constant folded value %s due to unsupported type %s.",
-                original_value.name,
-                type(array),
+                output_name,
+                type(output_array),
             )
             return None
 
-        tensor = ir.tensor(array)
-        tensor.name = original_value.name
-        initializer = ir.Value(
-            name=original_value.name,
-            type=ir.TensorType(ir.DataType(tensor.dtype)),
-            shape=tensor.shape,  # type: ignore[arg-type]
-            const_value=tensor,
-        )
+        tensor = ir.tensor(output_array)
+        tensor.name = output_name
 
-        if array.size > self.output_size_limit:
-            # Handle examples like Transpose(weight) to be folded even if the size is large,
-            # as long as weight has no other uses. This won't increase model size.
+        # Size gating (shared logic)
+        if output_array.size > self.output_size_limit:
             removed_input_size = 0
-            for input in node.inputs:
-                if (input is not None) and (len(input.uses()) == 1):
-                    array = _get_numpy_value(input)
-                    if array is not None:
-                        removed_input_size += array.size
-            increased_size = array.size - removed_input_size
+            for input_val in node.inputs:
+                if (input_val is not None) and (len(input_val.uses()) == 1):
+                    input_array = _get_numpy_value(input_val)
+                    if input_array is not None:
+                        removed_input_size += input_array.size
+            increased_size = output_array.size - removed_input_size
             if increased_size > 0:
                 logger.info(
-                    "Skip storing constant folded nvalue %s due to large size %s.",
-                    original_value.name,
-                    array.size,
+                    "Skip storing constant folded array %s due to large size %s.",
+                    output_name,
+                    output_array.size,
                 )
                 return None
 
+        return tensor
+
+    def new_constant(self, node: ir.Node, array: np.ndarray | Any) -> ir.Node | None:
+        """Create a new Constant node with the given array as its value."""
+        original_value = node.outputs[0]
+
+        tensor = self._prepare_folded_tensor(node, original_value.name, array)
+        if tensor is None:
+            return None
+
+        logger.debug(
+            "New constant for value %s dtype: %s shape: %s",
+            original_value.name,
+            array.dtype,
+            array.shape,
+        )
+
+        node = ir.Node("", "Constant", inputs=[], attributes=(ir.AttrTensor("value", tensor),))
+        return node
+
+    def new_initializer(self, node: ir.Node, array: np.ndarray | Any) -> ir.Value | None:
+        """Create a new initializer value with the given array as its value."""
+        original_value = node.outputs[0]
+
+        tensor = self._prepare_folded_tensor(node, original_value.name, array)
+        if tensor is None:
+            return None
+
+        initializer = ir.Value(
+            name=original_value.name,
+            type=ir.TensorType(ir.DataType(tensor.dtype)),
+            shape=tensor.shape,  # type: ignore[arg-type]
+            const_value=tensor,
+        )
+
         logger.debug(
             "New Initializer for value %s dtype: %s shape: %s",
             original_value.name,
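
The size gate in `_prepare_folded_tensor` is easiest to see with concrete numbers: a folded output larger than `output_size_limit` is still accepted when it merely replaces inputs that have no other uses (e.g. folding `Transpose(weight)` when `weight` feeds nothing else), because deleting those inputs offsets the growth. Below is a minimal standalone sketch of that accounting, using plain numpy arrays and a made-up limit in place of the real `ir.Value` machinery:

```python
import numpy as np

OUTPUT_SIZE_LIMIT = 1024  # stand-in for self.output_size_limit


def would_store(output: np.ndarray, single_use_inputs: list[np.ndarray]) -> bool:
    """Mirrors the gate: accept a large folded output only when the
    single-use inputs it eliminates are at least as big."""
    if output.size <= OUTPUT_SIZE_LIMIT:
        return True
    removed_input_size = sum(a.size for a in single_use_inputs)
    return output.size - removed_input_size <= 0


weight = np.zeros((64, 64), dtype=np.float32)  # 4096 elements, single use
folded = np.ascontiguousarray(weight.T)        # Transpose(weight): also 4096 elements

print(would_store(folded, [weight]))  # True: net model size is unchanged
print(would_store(folded, []))        # False: would add 4096 new elements
```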
@@ -1099,7 +1132,7 @@ def new_initializer(self, node: ir.Node, array) -> ir.Value | None:
 
         return initializer
 
-    def process_node(self, node: ir.Node) -> Replacement | None:
+    def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None:
         """Process a node and return a Replacement if the node can be replaced."""
         for i, value in enumerate(node.inputs):
             sym_value = self._state.get_sym_value(value)
@@ -1252,6 +1285,12 @@ def convert(av):
         if outputs is None:
             return None
         if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
+            # We don't support initializers in functions, so we need to create Constant nodes
+            if is_function:
+                replacement = self.new_constant(node, outputs)
+                if replacement is None:
+                    return None
+                return Replacement(replacement.outputs, [replacement])
             new_initializer_value = self.new_initializer(node, outputs)
             if new_initializer_value is None:
                 return None
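
The `is_function` branch exists because, as the added comment notes, initializers are not supported inside function bodies, so a folded value there has to be materialized as a `Constant` node instead. A small illustration with the stock `onnx` helper API (the tensor and output names here are made up for the example) of what such a node looks like at the proto level:

```python
import numpy as np
from onnx import helper, numpy_helper

# Inside a function body a folded value must travel as a Constant node,
# since there is no initializer list to attach it to.
folded = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
const_node = helper.make_node(
    "Constant",
    inputs=[],
    outputs=["folded_value"],
    value=numpy_helper.from_array(folded, name="folded_value"),
)
print(const_node)
```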
@@ -1301,7 +1340,8 @@ def visit_attribute(self, attr: ir.Attr) -> None:
         self.visit_graph(graph)
 
     def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
-        replacement = self.process_node(node)
+        is_function = isinstance(root, ir.Function)
+        replacement = self.process_node(node, is_function=is_function)
         if replacement is None:
             # No change. Process attributes.
             for attr in node.attributes.values():
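
`visit_node` is where the new flag originates: the traversal already knows whether `root` is a graph or a function, and simply threads that fact down to `process_node`. A toy sketch of the same dispatch, with stand-in classes instead of the real `ir` module:

```python
class Graph: ...
class Function(Graph): ...


def process_node(node: str, is_function: bool) -> None:
    kind = "Constant node" if is_function else "initializer"
    print(f"would fold {node!r} into a {kind}")


def visit_node(node: str, root: Graph) -> None:
    # Same dispatch as the diff: the folding strategy depends on the container.
    process_node(node, is_function=isinstance(root, Function))


visit_node("Add_0", Function())  # would fold 'Add_0' into a Constant node
visit_node("Add_0", Graph())     # would fold 'Add_0' into an initializer
```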