1414
1515
1616class FoldConstantsTest (unittest .TestCase ):
17- def _fold (self , model : ir .Model | str , onnx_shape_inference = False , ** kwargs ):
17+ def _fold (
18+ self ,
19+ model : ir .Model | str ,
20+ onnx_shape_inference : bool = False ,
21+ dce : bool = True ,
22+ ** kwargs ,
23+ ):
1824 if isinstance (model , str ):
1925 model = ir .from_onnx_text (model )
2026 _constant_folding .fold_constants (
2127 model , onnx_shape_inference = onnx_shape_inference , ** kwargs
2228 )
23- optimizer .remove_unused_nodes (model )
29+ if dce :
30+ optimizer .remove_unused_nodes (model )
2431 # Ensure the model is valid after optimization
2532 onnx .checker .check_model (ir .serde .serialize_model (model ))
2633 return model
@@ -50,9 +57,16 @@ def test_fold_cast_like(self):
5057 }
5158 """
5259
53- optimized = self ._fold (model )
54- self .assertEqual (len (optimized .graph ), 1 )
60+ optimized = self ._fold (model , dce = False )
5561 self .assertIn ("four" , optimized .graph .initializers )
62+ np .testing .assert_equal (
63+ optimized .graph .initializers ["four" ].const_value , np .array (4.0 )
64+ )
65+ # Intermediates should be removed
66+ self .assertNotIn ("two_float" , optimized .graph .initializers )
67+
68+ optimized = self ._fold (model , dce = True )
69+ self .assertEqual (len (optimized .graph ), 1 )
5670
5771 def test_fold_shape (self ):
5872 model = """
@@ -66,9 +80,18 @@ def test_fold_shape(self):
6680 }
6781 """
6882
69- optimized = self ._fold (model )
70- self .assertEqual (len (optimized .graph ), 1 )
83+ optimized = self ._fold (model , dce = False )
7184 self .assertIn ("four" , optimized .graph .initializers )
85+ np .testing .assert_equal (
86+ optimized .graph .initializers ["four" ].const_value , np .array (4.0 )
87+ )
88+ # Intermediates should be removed
89+ self .assertNotIn ("two_float" , optimized .graph .initializers )
90+ self .assertNotIn ("rank" , optimized .graph .initializers )
91+ self .assertNotIn ("shape" , optimized .graph .initializers )
92+
93+ optimized = self ._fold (model , dce = True )
94+ self .assertEqual (len (optimized .graph ), 1 )
7295
7396 def test_fold_shape_slice (self ):
7497 model = """
0 commit comments