Skip to content

Commit 54de741

Browse files
authored
Refactor rewrite rules into the rewriter.rules namespace (#2531)
Organize all rules into a directory that is not with the rewriter infrastructure: - `onnxscript.rewriter.rules.common.*` for existing rules - `onnxscript.rewriter.rules.fusion.*` for onnx fusion rules --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 7b04774 commit 54de741

34 files changed

+289
-207
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,35 +22,35 @@
2222
import onnx_ir.passes.common as common_passes
2323

2424
from onnxscript import ir
25-
from onnxscript.rewriter import (
26-
basic_rules,
27-
broadcast_to_matmul,
28-
cast_constant_of_shape,
29-
collapse_slices,
30-
fuse_pad_into_conv,
31-
fuse_relus_clips,
32-
no_op,
33-
pattern,
34-
redundant_scatter_nd,
35-
)
25+
from onnxscript.rewriter import pattern
3626
from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus
3727
from onnxscript.rewriter._rewrite_rule import (
3828
RewriterContext,
3929
RewriteRule,
4030
RewriteRuleClassBase,
4131
RewriteRuleSet,
4232
)
33+
from onnxscript.rewriter.rules.common import (
34+
_basic_rules,
35+
_broadcast_to_matmul,
36+
_cast_constant_of_shape,
37+
_collapse_slices,
38+
_fuse_pad_into_conv,
39+
_fuse_relus_clips,
40+
_no_op,
41+
_redundant_scatter_nd,
42+
)
4343

4444
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
4545
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
46-
*no_op.rules.rules, # TODO: merge this rule into constant folding?
47-
*broadcast_to_matmul.rules.rules,
48-
*cast_constant_of_shape.rules.rules,
49-
*collapse_slices.rules.rules,
50-
*fuse_relus_clips.fuse_relus_clips_rules().rules,
51-
*basic_rules.basic_optimization_rules().rules,
52-
*redundant_scatter_nd.rules.rules,
53-
*fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules,
46+
*_no_op.rules, # TODO: merge this rule into constant folding?
47+
*_broadcast_to_matmul.rules,
48+
*_cast_constant_of_shape.rules,
49+
*_collapse_slices.rules,
50+
*_fuse_relus_clips.rules,
51+
*_basic_rules.basic_optimization_rules(),
52+
*_redundant_scatter_nd.rules,
53+
*_fuse_pad_into_conv.rules,
5454
)
5555

5656

onnxscript/rewriter/onnx_fusions/_onnx_fusions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import onnx_ir as ir
66

7-
from onnxscript.rewriter.onnx_fusions import _rms_normalization, _rotary_embedding
7+
from onnxscript.rewriter.rules.fusion import _rms_normalization, _rotary_embedding
88

99

1010
def _get_onnx_opset_version(model: ir.Model) -> int | None:

onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from parameterized import parameterized
99

1010
import onnxscript
11-
import onnxscript.rewriter.onnx_fusions as onnx_fusions
11+
from onnxscript.rewriter import onnx_fusions
1212
from onnxscript.rewriter.models import _rotary_embedding_models
1313

1414

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets
99
import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization
1010
from onnxscript.optimizer import optimize
11-
from onnxscript.rewriter import gemm_to_matmul_add, rewrite
11+
from onnxscript.rewriter import rewrite
1212
from onnxscript.rewriter.ort_fusions import (
1313
instance_to_group_normalization,
1414
softmax,
@@ -33,6 +33,7 @@
3333
fuse_skip_layer_normalization,
3434
fuse_skip_rms_normalization,
3535
)
36+
from onnxscript.rewriter.rules.common import _gemm_to_matmul_add
3637

3738
ORT_PATTERN_REWRITE_RULES = [
3839
*softmax.rules.rules,
@@ -133,7 +134,7 @@ def optimize_for_ort(
133134
- The optimized `ir.Model` after applying transformer-specific fusions.
134135
- A dictionary with a count of each of the fusions applied.
135136
"""
136-
rewrite(model, [gemm_to_matmul_add.rule])
137+
rewrite(model, [_gemm_to_matmul_add.gemm_to_matmul_add_rule])
137138
model, fusion_count = fuse_xformers(
138139
model,
139140
debug=debug,

onnxscript/rewriter/pattern_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
import onnxscript.optimizer
1313
from onnxscript import FLOAT, ir, script
1414
from onnxscript import opset17 as op
15-
from onnxscript.rewriter import cast_constant_of_shape, pattern
15+
from onnxscript.rewriter import pattern
16+
from onnxscript.rewriter.rules.common import _cast_constant_of_shape
1617

1718
logger = logging.getLogger(__name__)
1819

@@ -306,7 +307,7 @@ def test_delayed_run_provides_correct_bindings_for_multiple_matches(self):
306307
"""
307308
)
308309
model = ir.serde.deserialize_model(model_proto)
309-
count = cast_constant_of_shape.rules.apply_to_model(model)
310+
count = _cast_constant_of_shape.rules.apply_to_model(model)
310311
self.assertEqual(count, 2)
311312
self.assertEqual(len(model.graph), 2)
312313
self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
__all__ = [
4+
"add_0_rule",
5+
"cast_cast_rule",
6+
"cast_constant_of_shape_rule",
7+
"cast_constant_of_shape_without_value_rule",
8+
"collapse_slice_rule",
9+
"collapse_slice2_rule",
10+
"div_by_1_rule",
11+
"dropout_inference_rule",
12+
"dropout_zero_rule",
13+
"fuse_batchnorm_into_conv_rule",
14+
"fuse_batchnorm_into_conv_transpose_rule",
15+
"fuse_batchnorm_into_gemm_rule",
16+
"fuse_pad_into_conv_integer_rule",
17+
"fuse_pad_into_conv_rule",
18+
"gemm_to_matmul_add_rule",
19+
"matmul_add_to_gemm_rule",
20+
"mul_by_1_rule",
21+
"no_op_cast_rule",
22+
"no_op_dynamic_scatter_nd_rule",
23+
"no_op_expand_rule",
24+
"no_op_static_scatter_nd_rule",
25+
"no_op_transpose_rule",
26+
"normalize_pad_format_conv_integer_rule",
27+
"normalize_pad_format_conv_rule",
28+
"one_reshape_matmul_reshape_rule",
29+
"reshape_reshape_rule",
30+
"slice_split_rule",
31+
"squeeze_reshape_1d_rule",
32+
"sub_0_rule",
33+
"successive_clip_relu_rule",
34+
"successive_clip_rule",
35+
"successive_relu_clip_rule",
36+
"successive_relu_rule",
37+
"transpose_a_matmul_add_to_gemm_rule",
38+
"transpose_ab_matmul_add_to_gemm_rule",
39+
"transpose_b_matmul_add_to_gemm_rule",
40+
"transpose_transpose_rule",
41+
"two_reshapes_matmul_reshape_rule",
42+
"unsqueeze_unsqueeze_rule",
43+
]
44+
45+
from onnxscript.rewriter.rules.common._basic_rules import (
46+
cast_cast_rule,
47+
no_op_cast_rule,
48+
no_op_expand_rule,
49+
no_op_transpose_rule,
50+
reshape_reshape_rule,
51+
slice_split_rule,
52+
squeeze_reshape_1d_rule,
53+
transpose_transpose_rule,
54+
unsqueeze_unsqueeze_rule,
55+
)
56+
from onnxscript.rewriter.rules.common._broadcast_to_matmul import (
57+
one_reshape_matmul_reshape_rule,
58+
two_reshapes_matmul_reshape_rule,
59+
)
60+
from onnxscript.rewriter.rules.common._cast_constant_of_shape import (
61+
cast_constant_of_shape_rule,
62+
cast_constant_of_shape_without_value_rule,
63+
)
64+
from onnxscript.rewriter.rules.common._collapse_slices import (
65+
collapse_slice2_rule,
66+
collapse_slice_rule,
67+
)
68+
from onnxscript.rewriter.rules.common._fuse_batchnorm import (
69+
fuse_batchnorm_into_conv_rule,
70+
fuse_batchnorm_into_conv_transpose_rule,
71+
fuse_batchnorm_into_gemm_rule,
72+
)
73+
from onnxscript.rewriter.rules.common._fuse_pad_into_conv import (
74+
fuse_pad_into_conv_integer_rule,
75+
fuse_pad_into_conv_rule,
76+
normalize_pad_format_conv_integer_rule,
77+
normalize_pad_format_conv_rule,
78+
)
79+
from onnxscript.rewriter.rules.common._fuse_relus_clips import (
80+
successive_clip_relu_rule,
81+
successive_clip_rule,
82+
successive_relu_clip_rule,
83+
successive_relu_rule,
84+
)
85+
from onnxscript.rewriter.rules.common._gemm_to_matmul_add import gemm_to_matmul_add_rule
86+
from onnxscript.rewriter.rules.common._matmul_add_to_gemm import (
87+
matmul_add_to_gemm_rule,
88+
transpose_a_matmul_add_to_gemm_rule,
89+
transpose_ab_matmul_add_to_gemm_rule,
90+
transpose_b_matmul_add_to_gemm_rule,
91+
)
92+
from onnxscript.rewriter.rules.common._no_op import (
93+
add_0_rule,
94+
div_by_1_rule,
95+
dropout_inference_rule,
96+
dropout_zero_rule,
97+
mul_by_1_rule,
98+
sub_0_rule,
99+
)
100+
from onnxscript.rewriter.rules.common._redundant_scatter_nd import (
101+
no_op_dynamic_scatter_nd_rule,
102+
no_op_static_scatter_nd_rule,
103+
)

onnxscript/rewriter/basic_rules.py renamed to onnxscript/rewriter/rules/common/_basic_rules.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,11 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
281281

282282
# Create rule instances
283283
cast_cast_rule = CastCast.rule()
284-
cast_identity_rule = CastIdentity.rule()
285-
expand_identity_rule = ExpandIdentity.rule()
284+
no_op_cast_rule = CastIdentity.rule()
285+
no_op_expand_rule = ExpandIdentity.rule()
286286
reshape_reshape_rule = ReshapeReshape.rule()
287287
slice_split_rule = SlicesSplit.rule()
288-
transpose_identity_rule = TransposeIdentity.rule()
288+
no_op_transpose_rule = TransposeIdentity.rule()
289289
transpose_transpose_rule = TransposeTranspose.rule()
290290
unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
291291
squeeze_reshape_1d_rule = SqueezeReshape.rule()
@@ -309,11 +309,11 @@ def basic_optimization_rules() -> RewriteRuleSet:
309309
return RewriteRuleSet(
310310
[
311311
cast_cast_rule,
312-
cast_identity_rule,
313-
expand_identity_rule,
312+
no_op_cast_rule,
313+
no_op_expand_rule,
314314
reshape_reshape_rule,
315315
slice_split_rule,
316-
transpose_identity_rule,
316+
no_op_transpose_rule,
317317
transpose_transpose_rule,
318318
unsqueeze_unsqueeze_rule,
319319
squeeze_reshape_1d_rule,

onnxscript/rewriter/basic_rules_test.py renamed to onnxscript/rewriter/rules/common/_basic_rules_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313
import onnxscript
1414
import onnxscript.onnx_types as ot
15-
import onnxscript.rewriter.basic_rules as basic_rules
1615
from onnxscript import ir
1716
from onnxscript.onnx_opset import opset18
17+
from onnxscript.rewriter.rules.common import _basic_rules
1818

1919
FLOAT = onnx.TensorProto.FLOAT
2020

@@ -98,7 +98,7 @@ def _check_model(
9898
]
9999
)
100100
def test_basic_optimization_rules_identity(self, _: str, model: ir.Model):
101-
rule_set = basic_rules.basic_optimization_rules()
101+
rule_set = _basic_rules.basic_optimization_rules()
102102
model_proto = ir.serde.serialize_model(model)
103103
rule_set.apply_to_model(model)
104104
rewritten_model = ir.serde.serialize_model(model)
@@ -126,7 +126,7 @@ def test_basic_optimization_rules_identity(self, _: str, model: ir.Model):
126126
]
127127
)
128128
def test_basic_optimization_rules_transpose_transpose(self, _: str, model: ir.Model):
129-
rule_set = basic_rules.basic_optimization_rules()
129+
rule_set = _basic_rules.basic_optimization_rules()
130130
model_proto = ir.serde.serialize_model(model)
131131
rule_set.apply_to_model(model)
132132
rewritten_model = ir.serde.serialize_model(model)
@@ -153,7 +153,7 @@ def cast_cast_model(x):
153153
]
154154
)
155155
def test_cast_cast_rule(self, _: str, type1, type2, type3):
156-
rule = basic_rules.cast_cast_rule
156+
rule = _basic_rules.cast_cast_rule
157157
model_proto = self._double_cast_model(type1, type2, type3)
158158
model = ir.serde.deserialize_model(model_proto)
159159
rule.apply_to_model(model)
@@ -172,7 +172,7 @@ def test_cast_cast_rule(self, _: str, type1, type2, type3):
172172
]
173173
)
174174
def test_cast_identity_rule(self, _: str, model: ir.Model):
175-
rule_set = basic_rules.basic_optimization_rules()
175+
rule_set = _basic_rules.basic_optimization_rules()
176176
model_proto = ir.serde.serialize_model(model)
177177
rule_set.apply_to_model(model)
178178
rewritten_model = ir.serde.serialize_model(model)
@@ -228,7 +228,7 @@ def test_cast_identity_rule(self, _: str, model: ir.Model):
228228
def test_expand_identity_rule(
229229
self, _: str, model: ir.Model, expected_nodes: tuple[str, ...]
230230
):
231-
rule_set = basic_rules.basic_optimization_rules()
231+
rule_set = _basic_rules.basic_optimization_rules()
232232
model_proto = ir.serde.serialize_model(model)
233233
rule_set.apply_to_model(model)
234234
rewritten_model = ir.serde.serialize_model(model)
@@ -310,7 +310,7 @@ def test_expand_identity_rule(
310310
]
311311
)
312312
def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model):
313-
rule_set = basic_rules.basic_optimization_rules()
313+
rule_set = _basic_rules.basic_optimization_rules()
314314
model_proto = ir.serde.serialize_model(model)
315315
rule_set.apply_to_model(model)
316316
rewritten_model = ir.serde.serialize_model(model)
@@ -369,7 +369,7 @@ def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model):
369369
]
370370
)
371371
def test_reshape_reshape_rule(self, _: str, model: ir.Model):
372-
rule_set = basic_rules.basic_optimization_rules()
372+
rule_set = _basic_rules.basic_optimization_rules()
373373
model_proto = ir.serde.serialize_model(model)
374374
rule_set.apply_to_model(model)
375375
rewritten_model = ir.serde.serialize_model(model)
@@ -420,15 +420,15 @@ def _slices_split_models(cls):
420420
def test_slices_split_rule(self):
421421
for model_proto in self._slices_split_models():
422422
ir_model = ir.serde.deserialize_model(model_proto)
423-
rule_set = basic_rules.basic_optimization_rules()
423+
rule_set = _basic_rules.basic_optimization_rules()
424424
rule_set.apply_to_model(ir_model)
425425
rewritten_model = ir.serde.serialize_model(ir_model)
426426

427427
self.assertEqual(["Split"], [n.op_type for n in rewritten_model.graph.node])
428428
self._check_model(model_proto, rewritten_model)
429429

430430
def test_squeeze_reshape_1d_rule(self):
431-
rule = basic_rules.squeeze_reshape_1d_rule
431+
rule = _basic_rules.squeeze_reshape_1d_rule
432432

433433
def check(model_script, expected_count) -> None:
434434
model_proto = model_script.to_model_proto()

0 commit comments

Comments
 (0)