@@ -2241,6 +2241,72 @@ def test_const_fold_cast_with_const(self):
22412241 self .run_and_compare (["res" ], {"X" : np .random .randn (* shape ).astype (np .int64 )}, model_proto ,
22422242 "Cast" , 0 )
22432243
2244+ def test_const_fold_add (self ):
2245+ shape = (6 , 6 )
2246+ const_tensor1 = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
2247+ vals = np .random .randn (* shape ).flatten ().astype (np .float32 ))
2248+ const_tensor2 = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
2249+ vals = np .random .randn (* shape ).flatten ().astype (np .float32 ))
2250+ node1 = helper .make_node ("Constant" , [], ["const1" ], value = const_tensor1 )
2251+ node2 = helper .make_node ("Constant" , [], ["const2" ], value = const_tensor2 )
2252+ node3 = helper .make_node ("Add" , ["const1" , "const2" ], ["add" ])
2253+ node4 = helper .make_node ("Add" , ["add" , "X" ], ["res" ])
2254+
2255+ graph = helper .make_graph (
2256+ [node1 , node2 , node3 , node4 ],
2257+ "test_const_fold_add" ,
2258+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , shape )],
2259+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , shape )],
2260+ )
2261+
2262+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
2263+ self .run_and_compare (["res" ], {"X" : np .random .randn (* shape ).astype (np .float32 )}, model_proto ,
2264+ "Add" , 1 )
2265+
2266+ def test_const_fold_sub (self ):
2267+ shape = (6 , 6 )
2268+ const_tensor1 = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
2269+ vals = np .random .randn (* shape ).flatten ().astype (np .float32 ))
2270+ const_tensor2 = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
2271+ vals = np .random .randn (* shape ).flatten ().astype (np .float32 ))
2272+ node1 = helper .make_node ("Constant" , [], ["const1" ], value = const_tensor1 )
2273+ node2 = helper .make_node ("Constant" , [], ["const2" ], value = const_tensor2 )
2274+ node3 = helper .make_node ("Sub" , ["const1" , "const2" ], ["sub" ])
2275+ node4 = helper .make_node ("Sub" , ["sub" , "X" ], ["res" ])
2276+
2277+ graph = helper .make_graph (
2278+ [node1 , node2 , node3 , node4 ],
2279+ "test_const_fold_sub" ,
2280+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , shape )],
2281+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , shape )],
2282+ )
2283+
2284+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
2285+ self .run_and_compare (["res" ], {"X" : np .random .randn (* shape ).astype (np .float32 )}, model_proto ,
2286+ "Sub" , 1 )
2287+
2288+ def test_const_fold_mul (self ):
2289+ shape = (6 , 6 )
2290+ const_tensor1 = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
2291+ vals = np .random .randn (* shape ).flatten ().astype (np .float32 ))
2292+ const_tensor2 = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
2293+ vals = np .random .randn (* shape ).flatten ().astype (np .float32 ))
2294+ node1 = helper .make_node ("Constant" , [], ["const1" ], value = const_tensor1 )
2295+ node2 = helper .make_node ("Constant" , [], ["const2" ], value = const_tensor2 )
2296+ node3 = helper .make_node ("Mul" , ["const1" , "const2" ], ["mul" ])
2297+ node4 = helper .make_node ("Mul" , ["mul" , "X" ], ["res" ])
2298+
2299+ graph = helper .make_graph (
2300+ [node1 , node2 , node3 , node4 ],
2301+ "test_const_fold_mul" ,
2302+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , shape )],
2303+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , shape )],
2304+ )
2305+
2306+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
2307+ self .run_and_compare (["res" ], {"X" : np .random .randn (* shape ).astype (np .float32 )}, model_proto ,
2308+ "Mul" , 1 )
2309+
22442310 def test_const_fold_split (self ):
22452311 shape = (2 , 6 , 1 )
22462312 const_tensor = helper .make_tensor (name = 'const_tensor' , data_type = TensorProto .FLOAT , dims = shape ,
0 commit comments