@@ -36,8 +36,8 @@ def test_fold_add(self):
3636 """
3737
3838 optimized = self ._fold (model )
39- self .assertEqual (len (optimized .graph ), 2 )
40- self .assertEqual ( optimized .graph [ 0 ]. outputs [ 0 ]. name , "four" )
39+ self .assertEqual (len (optimized .graph ), 1 )
40+ self .assertIn ( "four" , optimized .graph . initializers )
4141
4242 def test_fold_cast_like (self ):
4343 model = """
@@ -51,8 +51,8 @@ def test_fold_cast_like(self):
5151 """
5252
5353 optimized = self ._fold (model )
54- self .assertEqual (len (optimized .graph ), 2 )
55- self .assertEqual ( optimized .graph [ 0 ]. outputs [ 0 ]. name , "four" )
54+ self .assertEqual (len (optimized .graph ), 1 )
55+ self .assertIn ( "four" , optimized .graph . initializers )
5656
5757 def test_fold_shape (self ):
5858 model = """
@@ -67,8 +67,8 @@ def test_fold_shape(self):
6767 """
6868
6969 optimized = self ._fold (model )
70- self .assertEqual (len (optimized .graph ), 2 )
71- self .assertEqual ( optimized .graph [ 0 ]. outputs [ 0 ]. name , "four" )
70+ self .assertEqual (len (optimized .graph ), 1 )
71+ self .assertIn ( "four" , optimized .graph . initializers )
7272
7373 def test_fold_shape_slice (self ):
7474 model = """
@@ -83,8 +83,8 @@ def test_fold_shape_slice(self):
8383 """
8484
8585 optimized = self ._fold (model )
86- self .assertEqual (len (optimized .graph ), 2 )
87- self .assertEqual ( optimized .graph [ 0 ]. outputs [ 0 ]. name , "four" )
86+ self .assertEqual (len (optimized .graph ), 1 )
87+ self .assertIn ( "four" , optimized .graph . initializers )
8888
8989 def test_fold_if_cond (self ):
9090 model = """
@@ -130,9 +130,11 @@ def test_fold_inside_if_branch(self):
130130 optimized = self ._fold (model )
131131 self .assertEqual (len (optimized .graph ), 1 )
132132 then_graph = optimized .graph [0 ].attributes ["then_branch" ].as_graph ()
133- self .assertEqual (len (then_graph ), 2 )
133+ self .assertEqual (len (then_graph ), 1 )
134+ self .assertIn ("temp" , then_graph .initializers )
134135 else_graph = optimized .graph [0 ].attributes ["else_branch" ].as_graph ()
135- self .assertEqual (len (else_graph ), 2 )
136+ self .assertEqual (len (else_graph ), 1 )
137+ self .assertIn ("temp" , else_graph .initializers )
136138
137139 def test_fold_if_propagate (self ):
138140 model = """
@@ -154,9 +156,8 @@ def test_fold_if_propagate(self):
154156 """
155157
156158 optimized = self ._fold (model )
157- self .assertEqual (len (optimized .graph ), 2 )
158- self .assertEqual (optimized .graph [0 ].outputs [0 ].name , "m_square" )
159- self .assertEqual (optimized .graph [0 ].op_type , "Constant" )
159+ self .assertEqual (len (optimized .graph ), 1 )
160+ self .assertIn ("m_square" , optimized .graph .initializers )
160161
161162 def test_fold_redundant_cast (self ):
162163 model = """
@@ -209,8 +210,8 @@ def test_shape_inference(self):
209210 """
210211
211212 optimized = self ._fold (model , onnx_shape_inference = True )
212- self .assertEqual (len (optimized .graph ), 2 )
213- self .assertEqual ( optimized .graph [ 0 ]. outputs [ 0 ]. name , "C" )
213+ self .assertEqual (len (optimized .graph ), 1 )
214+ self .assertIn ( "C" , optimized .graph . initializers )
214215
215216 def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split (
216217 self ,
@@ -614,7 +615,8 @@ def test_input_size_limit(self):
614615 # Since there is no increase in model-size, output-size is not a concern.
615616 optimized = self ._fold (model , input_size_limit = 256 * 256 , output_size_limit = 256 * 256 )
616617 ops = [node .op_type for node in optimized .graph ]
617- self .assertEqual (ops , ["Constant" , "Add" ])
618+ self .assertEqual (ops , ["Add" ])
619+ self .assertIn ("w_squared" , optimized .graph .initializers )
618620
619621 def test_transpose_is_always_folded (self ):
620622 model_text = """
@@ -633,7 +635,8 @@ def test_transpose_is_always_folded(self):
633635 # Input size limit will not prevent folding of Transpose op
634636 optimized = self ._fold (model , input_size_limit = 1 )
635637 ops = [node .op_type for node in optimized .graph ]
636- self .assertEqual (ops , ["Constant" ])
638+ self .assertEqual (ops , [])
639+ self .assertIn ("z" , optimized .graph .initializers )
637640
638641 def test_node_is_folded_if_specified_as_should_fold (self ):
639642 model_text = """
@@ -656,9 +659,10 @@ def test_node_is_folded_if_specified_as_should_fold(self):
656659 model , should_fold = lambda node : node .op_type == "ConstantOfShape" or None
657660 )
658661 ops = [node .op_type for node in optimized .graph ]
659- self .assertEqual (ops , ["Constant" ])
662+ self .assertEqual (ops , [])
663+ self .assertIn ("z" , optimized .graph .initializers )
660664 np .testing .assert_array_equal (
661- optimized .graph .node ( 0 ). attributes [ "value " ].as_tensor (). numpy () ,
665+ optimized .graph .initializers [ "z " ].const_value ,
662666 np .ones ((42 , 42 ), dtype = np .int64 ),
663667 )
664668
0 commit comments