1212
1313import onnxscript
1414import onnxscript .onnx_types as ot
15- import onnxscript .rewriter .basic_rules as basic_rules
1615from onnxscript import ir
1716from onnxscript .onnx_opset import opset18
17+ from onnxscript .rewriter .rules .common import _basic_rules
1818
1919FLOAT = onnx .TensorProto .FLOAT
2020
@@ -98,7 +98,7 @@ def _check_model(
9898 ]
9999 )
100100 def test_basic_optimization_rules_identity (self , _ : str , model : ir .Model ):
101- rule_set = basic_rules .basic_optimization_rules ()
101+ rule_set = _basic_rules .basic_optimization_rules ()
102102 model_proto = ir .serde .serialize_model (model )
103103 rule_set .apply_to_model (model )
104104 rewritten_model = ir .serde .serialize_model (model )
@@ -126,7 +126,7 @@ def test_basic_optimization_rules_identity(self, _: str, model: ir.Model):
126126 ]
127127 )
128128 def test_basic_optimization_rules_transpose_transpose (self , _ : str , model : ir .Model ):
129- rule_set = basic_rules .basic_optimization_rules ()
129+ rule_set = _basic_rules .basic_optimization_rules ()
130130 model_proto = ir .serde .serialize_model (model )
131131 rule_set .apply_to_model (model )
132132 rewritten_model = ir .serde .serialize_model (model )
@@ -153,7 +153,7 @@ def cast_cast_model(x):
153153 ]
154154 )
155155 def test_cast_cast_rule (self , _ : str , type1 , type2 , type3 ):
156- rule = basic_rules .cast_cast_rule
156+ rule = _basic_rules .cast_cast_rule
157157 model_proto = self ._double_cast_model (type1 , type2 , type3 )
158158 model = ir .serde .deserialize_model (model_proto )
159159 rule .apply_to_model (model )
@@ -172,7 +172,7 @@ def test_cast_cast_rule(self, _: str, type1, type2, type3):
172172 ]
173173 )
174174 def test_cast_identity_rule (self , _ : str , model : ir .Model ):
175- rule_set = basic_rules .basic_optimization_rules ()
175+ rule_set = _basic_rules .basic_optimization_rules ()
176176 model_proto = ir .serde .serialize_model (model )
177177 rule_set .apply_to_model (model )
178178 rewritten_model = ir .serde .serialize_model (model )
@@ -228,7 +228,7 @@ def test_cast_identity_rule(self, _: str, model: ir.Model):
228228 def test_expand_identity_rule (
229229 self , _ : str , model : ir .Model , expected_nodes : tuple [str , ...]
230230 ):
231- rule_set = basic_rules .basic_optimization_rules ()
231+ rule_set = _basic_rules .basic_optimization_rules ()
232232 model_proto = ir .serde .serialize_model (model )
233233 rule_set .apply_to_model (model )
234234 rewritten_model = ir .serde .serialize_model (model )
@@ -310,7 +310,7 @@ def test_expand_identity_rule(
310310 ]
311311 )
312312 def test_unsqueeze_unsqueeze_rule (self , _ : str , model : ir .Model ):
313- rule_set = basic_rules .basic_optimization_rules ()
313+ rule_set = _basic_rules .basic_optimization_rules ()
314314 model_proto = ir .serde .serialize_model (model )
315315 rule_set .apply_to_model (model )
316316 rewritten_model = ir .serde .serialize_model (model )
@@ -369,7 +369,7 @@ def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model):
369369 ]
370370 )
371371 def test_reshape_reshape_rule (self , _ : str , model : ir .Model ):
372- rule_set = basic_rules .basic_optimization_rules ()
372+ rule_set = _basic_rules .basic_optimization_rules ()
373373 model_proto = ir .serde .serialize_model (model )
374374 rule_set .apply_to_model (model )
375375 rewritten_model = ir .serde .serialize_model (model )
@@ -420,15 +420,15 @@ def _slices_split_models(cls):
420420 def test_slices_split_rule (self ):
421421 for model_proto in self ._slices_split_models ():
422422 ir_model = ir .serde .deserialize_model (model_proto )
423- rule_set = basic_rules .basic_optimization_rules ()
423+ rule_set = _basic_rules .basic_optimization_rules ()
424424 rule_set .apply_to_model (ir_model )
425425 rewritten_model = ir .serde .serialize_model (ir_model )
426426
427427 self .assertEqual (["Split" ], [n .op_type for n in rewritten_model .graph .node ])
428428 self ._check_model (model_proto , rewritten_model )
429429
430430 def test_squeeze_reshape_1d_rule (self ):
431- rule = basic_rules .squeeze_reshape_1d_rule
431+ rule = _basic_rules .squeeze_reshape_1d_rule
432432
433433 def check (model_script , expected_count ) -> None :
434434 model_proto = model_script .to_model_proto ()
0 commit comments