
Commit 821015a

Stonesjtu and Kaiyu Shi authored
Add Conv-Affine(Mul+Add) and hardswish fusion (#2472)
Close #2468

Absorbs Affine into Conv:
- Mul + Add + Conv ==> Conv
- Conv + Mul + Add ==> Conv

Fuses HardSwish:
- Add + Clip + Div ==> HardSigmoid
- HardSigmoid + Mul ==> HardSwish
- Add + Clip + Mul + Div ==> HardSwish (since operator order matters, this case needs its own rewrite pattern)

These rules may not be generic enough, but they work for us on the `paddleOCRv4` model. One open question: HardSwish was introduced in opset 14; will onnxscript handle older opset versions, or should the rewrite rules take care of this?

---------

Co-authored-by: Kaiyu Shi <kaiyu@bytedance.com>
1 parent 0e79b62 commit 821015a
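The Conv-then-affine direction of the fusion is plain algebra: every output element of a convolution is sum(W * patch) + b, so multiplying the output by a scalar s and adding a scalar t is the same as convolving with W * s and bias b * s + t, whatever the padding, strides, or groups. A minimal numpy sketch of that identity (a hand-rolled reference convolution, not the onnxscript API; the kernel covers the whole input so the convolution collapses to a tensordot):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 5, 5)).astype(np.float32)  # NCHW input
w = rng.standard_normal((4, 3, 5, 5)).astype(np.float32)  # OIHW weight, kernel == input size
b = rng.standard_normal(4).astype(np.float32)
s, t = 2.0, 3.0  # scalar Mul and Add operands


def conv_full(x, w, b):
    # Kernel size equals the spatial size, stride 1, no padding: one output
    # position per channel, so Conv reduces to a tensordot over (C, H, W) plus bias.
    return np.tensordot(x, w, axes=([1, 2, 3], [1, 2, 3])) + b


# Conv + Mul + Add ==> Conv with folded weight and bias.
np.testing.assert_allclose(
    conv_full(x, w, b) * s + t,
    conv_full(x, w * s, b * s + t),
    rtol=1e-4,
    atol=1e-4,
)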

File tree

5 files changed: +494 −0 lines changed


onnxscript/rewriter/rules/common/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -2,18 +2,21 @@
 # Licensed under the MIT License.
 __all__ = [
     "add_0_rule",
+    "affine_conv_fusion_rule",
     "cast_cast_rule",
     "cast_constant_of_shape_rule",
     "cast_constant_of_shape_without_value_rule",
     "collapse_slice_rule",
     "collapse_slice2_rule",
+    "conv_affine_fusion_rule",
     "div_by_1_rule",
     "dropout_inference_rule",
     "dropout_zero_rule",
     "flatten_to_reshape_rule",
     "fuse_batchnorm_into_conv_rule",
     "fuse_batchnorm_into_conv_transpose_rule",
     "fuse_batchnorm_into_gemm_rule",
+    "fuse_hardswish_rules",
     "fuse_pad_into_conv_integer_rule",
     "fuse_pad_into_conv_rule",
     "min_min_rule",
@@ -76,6 +79,11 @@
     fuse_batchnorm_into_conv_transpose_rule,
     fuse_batchnorm_into_gemm_rule,
 )
+from onnxscript.rewriter.rules.common._fuse_conv_affine import (
+    affine_conv_fusion_rule,
+    conv_affine_fusion_rule,
+)
+from onnxscript.rewriter.rules.common._fuse_hardswish import fuse_hardswish_rules
 from onnxscript.rewriter.rules.common._fuse_pad_into_conv import (
     fuse_pad_into_conv_integer_rule,
     fuse_pad_into_conv_rule,
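With these exports in place, the new rules can be applied like any other rule in onnxscript.rewriter.rules.common. A usage sketch under the API shown in the test below (the model path is a placeholder; fuse_hardswish_rules is exported the same way):

import onnx

from onnxscript import ir
from onnxscript.rewriter import rewrite
from onnxscript.rewriter.rules.common import (
    affine_conv_fusion_rule,
    conv_affine_fusion_rule,
)

# Load a model, run the two Conv-Affine rules, and save the result.
model = ir.from_proto(onnx.load("model.onnx"))  # placeholder path
model = rewrite(
    model,
    pattern_rewrite_rules=[affine_conv_fusion_rule, conv_affine_fusion_rule],
)
onnx.save(ir.to_proto(model), "model_fused.onnx")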
onnxscript/rewriter/rules/common/_fuse_conv_affine.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Absorbs affine operation into convolution (best effort):
- Conv(Mul(Add(x))) -> Conv (only conv without padding can be fused)
- Add(Mul(Conv)) -> Conv (for all convolutions)
"""

from __future__ import annotations

import numpy as np
import onnx_ir as ir

from onnxscript.rewriter import pattern
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._ir_utils import get_const_value, get_singleton_value


class _ConvAffineFusionBase(pattern.RewriteRuleClassBase):
    def check(
        self,
        context,
        x: ir.Value,
        w: ir.Value,
        b: ir.Value,
        scale: ir.Value,
        offset: ir.Value,
        conv_out: ir.Value,
    ) -> MatchResult:
        check_result = MatchResult()
        if get_const_value(w) is None:
            return check_result.fail("The weight of Conv should be constant")
        if get_const_value(b) is None:
            return check_result.fail("The bias of Conv should be constant")
        if get_singleton_value(scale) is None:
            return check_result.fail("Operand for Mul should be constant scalar value")
        if get_singleton_value(offset) is None:
            return check_result.fail("Operand for Add should be constant scalar value")
        return check_result


class AffineConvFusion(_ConvAffineFusionBase):
    """Pattern: scalar Mul + scalar Add + Conv (1x1) --> Conv(1x1)"""

    def pattern(
        self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value
    ) -> ir.Value:
        return op.Conv(
            x * scale + offset,
            w,
            b,
            pads=[0, 0, 0, 0],
            _allow_other_attributes=True,
            _outputs=["conv_out"],
        )

    def rewrite(
        self,
        op: ir.tape.Tape,
        x: ir.Value,
        w: ir.Value,
        b: ir.Value,
        scale: ir.Value,
        offset: ir.Value,
        conv_out: ir.Value,
    ) -> ir.Value:
        scale_value = scale.const_value.numpy()
        offset_value = offset.const_value.numpy()
        w_value = w.const_value.numpy()
        b_value = b.const_value.numpy()
        scaled_w_value = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled")
        offset_bias = ir.tensor(
            b_value + np.sum(w_value * offset_value, axis=(1, 2, 3), keepdims=False)
        )
        offset_bias = op.initializer(offset_bias, b.name + "_offset")
        conv_attributes = conv_out.producer().attributes
        return op.Conv(x, scaled_w_value, offset_bias, **conv_attributes)


class ConvAffineFusion(_ConvAffineFusionBase):
    """Pattern: Conv + scalar Mul + scalar Add --> Conv(1x1)"""

    def pattern(
        self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value
    ) -> ir.Value:
        return (
            op.Conv(x, w, b, _allow_other_attributes=True, _outputs=["conv_out"]) * scale
            + offset
        )

    def rewrite(
        self,
        op: ir.tape.Tape,
        x: ir.Value,
        w: ir.Value,
        b: ir.Value,
        scale: ir.Value,
        offset: ir.Value,
        conv_out: ir.Value,
    ) -> ir.Value:
        scale_value = scale.const_value.numpy()
        offset_value = offset.const_value.numpy()
        w_value = w.const_value.numpy()
        b_value = b.const_value.numpy()
        scaled_w_weight = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled")
        offset_bias = ir.tensor(b_value * scale_value + offset_value)
        offset_bias = op.initializer(offset_bias, b.name + "_offset")
        conv_attributes = conv_out.producer().attributes
        return op.Conv(x, scaled_w_weight, offset_bias, **conv_attributes)


affine_conv_fusion_rule = AffineConvFusion().rule()
conv_affine_fusion_rule = ConvAffineFusion().rule()
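The AffineConvFusion rewrite folds the pre-Conv offset into the bias as b + offset * sum(W) over the input-channel and spatial axes, which is what the np.sum(..., axis=(1, 2, 3)) above computes. The identity only holds when the Conv has no padding: zero-padded positions never received the + offset, which is exactly what the pads=[0, 0, 0, 0] constraint in the pattern guards against. A numpy sketch of both the exact unpadded case and the padded mismatch (hand-rolled reference convolution, illustrative only, not the onnxscript API):

import numpy as np


def conv2d(x, w, b, pad=0):
    # Naive NCHW convolution: stride 1, symmetric zero padding, for reference only.
    x = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    n, _, h, width = x.shape
    oc, _, kh, kw = w.shape
    out = np.zeros((n, oc, h - kh + 1, width - kw + 1), dtype=np.float32)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = x[:, :, i : i + kh, j : j + kw]
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3])) + b
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8)).astype(np.float32)
w = rng.standard_normal((4, 3, 3, 3)).astype(np.float32)
b = rng.standard_normal(4).astype(np.float32)
s, t = 2.0, 3.0
folded_w = w * s
folded_b = b + t * np.sum(w, axis=(1, 2, 3))

# Without padding, the folded Conv matches Mul + Add + Conv exactly ...
np.testing.assert_allclose(
    conv2d(x * s + t, w, b), conv2d(x, folded_w, folded_b), rtol=1e-4, atol=1e-4
)
# ... but with zero padding the border outputs disagree, so the rule must not match.
assert not np.allclose(conv2d(x * s + t, w, b, pad=1), conv2d(x, folded_w, folded_b, pad=1))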
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest

import numpy as np

from onnxscript import ir
from onnxscript.rewriter import rewrite, testing
from onnxscript.rewriter.rules.common import (
    affine_conv_fusion_rule,
    conv_affine_fusion_rule,
)


class FuseConvAffineTest(unittest.TestCase):
    def clone_model(self, model: ir.Model) -> ir.Model:
        return ir.from_proto(ir.to_proto(model))

    def test_conv_affine_fusion(self):
        tape = ir.tape.Tape()
        x = ir.Input(
            "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT)
        )
        w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w"))
        b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b"))
        scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale"))
        offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset"))

        conv_out = tape.op("Conv", [x, w, b], attributes={"pads": [1, 1, 1, 1]})
        mul_out = tape.op("Mul", [conv_out, scale])
        z = tape.op(
            "Add",
            [mul_out, offset],
            output=ir.Input(
                "z",
                shape=ir.Shape([1, 3, 32, 32]),
                type=ir.TensorType(ir.DataType.FLOAT),
            ),
        )

        model = ir.Model(
            ir.Graph(
                inputs=[x],
                outputs=[z],
                nodes=tape.nodes,
                initializers=tape.initializers,
                opset_imports={"": 17},
            ),
            ir_version=8,
        )
        rewritten_model = self.clone_model(model)
        rewritten_model = rewrite(
            rewritten_model,
            pattern_rewrite_rules=[conv_affine_fusion_rule],
        )
        # Check that Mul and Add are fused into Conv
        self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes())

        # Check that the results are numerically equal
        rng = np.random.default_rng(42)
        inputs = [
            rng.random((1, 3, 32, 32), dtype=np.float32),
        ]
        testing.assert_numerically_equal(model, rewritten_model, inputs)

    def test_affine_conv_fusion_without_pad(self):
        tape = ir.tape.Tape()
        x = ir.Input(
            "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT)
        )
        w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w"))
        b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b"))
        scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale"))
        offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset"))

        mul_out = tape.op("Mul", [x, scale])
        z = tape.op(
            "Add",
            [mul_out, offset],
            output=ir.Input(
                "z",
                shape=ir.Shape([1, 3, 32, 32]),
                type=ir.TensorType(ir.DataType.FLOAT),
            ),
        )
        conv_out = tape.op("Conv", [z, w, b], attributes={"pads": [0, 0, 0, 0]})

        model = ir.Model(
            ir.Graph(
                inputs=[x],
                outputs=[conv_out],
                nodes=tape.nodes,
                initializers=tape.initializers,
                opset_imports={"": 17},
            ),
            ir_version=8,
        )
        rewritten_model = self.clone_model(model)
        rewritten_model = rewrite(
            rewritten_model,
            pattern_rewrite_rules=[affine_conv_fusion_rule],
        )
        # Check that Mul and Add are fused into Conv
        self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes())

        # Check that the results are numerically equal
        rng = np.random.default_rng(42)
        inputs = [
            rng.random((1, 3, 32, 32), dtype=np.float32),
        ]
        testing.assert_numerically_equal(model, rewritten_model, inputs)


if __name__ == "__main__":
    unittest.main()
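The HardSwish half of the commit (the _fuse_hardswish module imported in __init__.py; its diff is among the remaining changed files not shown above) rests on standard operator identities: Clip(x + 3, 0, 6) / 6 equals ONNX HardSigmoid with alpha = 1/6 and beta = 0.5, and x times that value equals ONNX HardSwish, which exists from opset 14 onward. A small numpy check of those identities (illustrative only, not the rewrite rules themselves):

import numpy as np

x = np.linspace(-8.0, 8.0, 101, dtype=np.float32)

add_clip_div = np.clip(x + 3.0, 0.0, 6.0) / 6.0   # Add + Clip + Div pattern
hard_sigmoid = np.clip(x / 6.0 + 0.5, 0.0, 1.0)   # ONNX HardSigmoid, alpha=1/6, beta=0.5
hard_swish = x * hard_sigmoid                     # ONNX HardSwish definition (opset >= 14)

np.testing.assert_allclose(add_clip_div, hard_sigmoid, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(x * add_clip_div, hard_swish, rtol=1e-5, atol=1e-5)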
