Skip to content

Commit 5762a69

Browse files
authored
[Rewriter]: add fusion rules for successive Min/Max patterns (#2500)
This PR adds the following transformations: - Min(Min(X)) -> Min(X) - Max(Max(X)) -> Max(X) - Min(Max(X)) -> Clip(X) - Max(Min(X)) -> Clip(X)
1 parent e76bfe0 commit 5762a69

File tree

4 files changed

+632
-0
lines changed

4 files changed

+632
-0
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
_collapse_slices,
3838
_fuse_pad_into_conv,
3939
_fuse_relus_clips,
40+
_min_max_to_clip,
4041
_no_op,
4142
_redundant_scatter_nd,
4243
)
@@ -47,6 +48,7 @@
4748
*_broadcast_to_matmul.rules,
4849
*_cast_constant_of_shape.rules,
4950
*_collapse_slices.rules,
51+
*_min_max_to_clip.rules,
5052
*_fuse_relus_clips.rules,
5153
*_basic_rules.basic_optimization_rules(),
5254
*_redundant_scatter_nd.rules,

onnxscript/rewriter/rules/common/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
"fuse_batchnorm_into_gemm_rule",
1616
"fuse_pad_into_conv_integer_rule",
1717
"fuse_pad_into_conv_rule",
18+
"min_min_rule",
19+
"max_max_rule",
20+
"min_max_rule",
21+
"max_min_rule",
1822
"gemm_to_matmul_add_rule",
1923
"matmul_add_to_gemm_rule",
2024
"mul_by_1_rule",
@@ -89,6 +93,12 @@
8993
transpose_ab_matmul_add_to_gemm_rule,
9094
transpose_b_matmul_add_to_gemm_rule,
9195
)
96+
from onnxscript.rewriter.rules.common._min_max_to_clip import (
97+
max_max_rule,
98+
max_min_rule,
99+
min_max_rule,
100+
min_min_rule,
101+
)
92102
from onnxscript.rewriter.rules.common._no_op import (
93103
add_0_rule,
94104
div_by_1_rule,
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Fuses successive Min/Max patterns in ONNX graphs.
4+
5+
Supported transformations:
6+
- Min(Min(X, c1, c2, ...), d1, d2, ...) → Min(X, fused_const)
7+
- Max(Max(X, c1, c2, ...), d1, d2, ...) → Max(X, fused_const)
8+
- Min(Max(X, lb1, lb2, ...), ub1, ub2, ...) → Clip(X, lb, ub)
9+
- Max(Min(X, ub1, ub2, ...), lb1, lb2, ...) → Clip(X, lb, ub)
10+
11+
Where:
12+
- fused_const is the reduction (min or max) over all constant inputs.
13+
- For Clip fusion:
14+
* All constant inputs must be scalars.
15+
* The effective lower bound is the maximum of all lower-bound constants.
16+
* The effective upper bound is the minimum of all upper-bound constants.
17+
18+
For the case of Max(Min(X, upper_bound), lower_bound):
19+
* The rule applies only if lower_bound ≤ upper_bound.
20+
21+
General constraints:
22+
- The first input may be any tensor.
23+
- All other inputs must be constant tensors (from Constant nodes or initializers).
24+
"""
25+
26+
import abc
27+
import functools
28+
from typing import ClassVar
29+
30+
import numpy as np
31+
import onnx_ir as ir
32+
33+
from onnxscript.rewriter._basics import MatchResult
34+
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
35+
36+
37+
class _FuseMinMaxBase(RewriteRuleClassBase, abc.ABC):
    """Base class for Min/Max fusion rewrites.

    A subclass supplies the two-node ``pattern``, the ``op_type`` of the fused
    node and ``compute_constants``; this class implements the shared ``check``
    and ``rewrite`` steps.

    Constraints:
        - All inputs except the first must be constants (from Constant nodes or initializers).
        - If ``need_scalars`` is True (Clip fusion), all constants must be scalars
          and both matched nodes must carry at least one constant.
        - If ``check_bounds`` is True (Clip fusion in the pattern Max(Min(X, upper_bound), lower_bound)), lower_bound ≤ upper_bound.
    """

    # Op type of the fused node ("Min", "Max" or "Clip"); set by each subclass.
    op_type: ClassVar[str]
    # True when every constant must be a single-element tensor (Clip fusion).
    need_scalars: ClassVar[bool] = False
    # True when the fused bounds must satisfy lower_bound <= upper_bound.
    check_bounds: ClassVar[bool] = False

    @abc.abstractmethod
    def compute_constants(
        self,
        first_node: ir.Node,
        second_node: ir.Node,
        input_name: str = "",
    ) -> list[tuple[ir.Tensor, str]]:
        """Return the constants of the fused node as ``(tensor, name)`` pairs.

        Args:
            first_node: The inner (first executed) node of the matched pattern.
            second_node: The outer node consuming ``first_node``'s output.
            input_name: Prefix used to build the names of the new initializers.

        Returns:
            The constant inputs of the fused node, in input order.
        """

    def rewrite(self, op, x, out1, out2):
        """Build the fused node ``op_type(x, *fused_constants)``."""
        first_node = out1.producer()
        second_node = out2.producer()

        # Compute new constants for the fused op.
        # NOTE(review): initializer names are derived from ``x.name`` only, so two
        # matches sharing the same ``x`` produce identically named initializers.
        # Confirm how the rewriter resolves such a collision.
        constants = self.compute_constants(first_node, second_node, x.name)

        initializers = [op.initializer(constant, name=name) for constant, name in constants]

        return op.op(
            self.op_type,
            inputs=[x, *initializers],
        )

    def _is_scalar(self, v: np.ndarray) -> bool:
        """Return True for a Python/NumPy scalar or an array holding exactly one element."""
        return np.isscalar(v) or np.size(v) == 1

    def check(self, context, out1, out2, **_):
        """Condition to check if we need to replace the pattern.

        Conditions:
        - There must be something to fuse: Clip fusion needs at least one constant
          on each of the two nodes, Min/Min and Max/Max fusion need at least one
          constant overall.
        - The min and max input nodes must not be graph inputs.
        - These inputs (except the first) must be constant values (from Constant nodes or initializers).
        - In the case of Min(Max) and Max(Min) patterns:
            * All inputs must be scalars (as Clip requires scalars).
        For Max(Min) pattern:
            * The lower bound must be less than or equal to the upper bound.

        Returns:
            MatchResult:
                Success if we need to replace the pattern, Failure otherwise.
        """
        del context  # Not used
        check_result = MatchResult()

        first_node = out1.producer()
        second_node = out2.producer()

        first_constants = first_node.inputs[1:]
        second_constants = second_node.inputs[1:]

        # Min and Max are variadic and valid with a single input, so a matched
        # node may carry no constant at all. Reject those matches here:
        # compute_constants would otherwise reduce over an empty sequence and raise.
        if self.need_scalars:
            if not first_constants or not second_constants:
                return check_result.fail(
                    "Clip fusion requires at least one constant input on both nodes."
                )
        elif not first_constants and not second_constants:
            return check_result.fail("No constant inputs to fuse.")

        # Ensure all inputs except the first are constants
        for input_ in (*first_constants, *second_constants):
            if ir.convenience.get_const_tensor(input_) is None:
                return check_result.fail(f"{input_.name} is not a constant.")

            # If scalars are required (Clip fusion), enforce scalar-ness
            if self.need_scalars and not self._is_scalar(input_.const_value.numpy()):
                return check_result.fail(f"{input_.name} is not a scalar.")

        if self.need_scalars and self.check_bounds:
            # For Clip fusion in the case of Max(Min(X, upper_bound), lower_bound): check that lower_bound <= upper_bound.
            # With lower_bound > upper_bound, Max(Min(X, ub), lb) evaluates to lb for
            # every X, which a Clip(X, lb, ub) node is not expected to reproduce.
            lower_bound, upper_bound = self.compute_constants(first_node, second_node)
            lower_value = lower_bound[0].numpy()
            upper_value = upper_bound[0].numpy()
            if lower_value > upper_value:
                return check_result.fail(
                    f"Invalid bounds: lower bound ({lower_value}) is greater "
                    f"than upper bound ({upper_value})."
                )

        return check_result
114+
115+
116+
class FuseSuccessiveMin(_FuseMinMaxBase):
    """Replaces ``Min(Min(X, c1, c2, ...), d1, d2, ...)`` with ``Min(X, fused_const)``.

    Constraints:
        - All inputs except the first must be constants (from Constant nodes or initializers).
    """

    op_type: ClassVar = "Min"

    def compute_constants(
        self,
        first_node: ir.Node,
        second_node: ir.Node,
        input_name: str = "",
    ) -> list[tuple[ir.Tensor, str]]:
        """Fold the constants of both ``Min`` nodes into a single constant.

        Args:
            first_node: The inner ``Min`` node.
            second_node: The outer ``Min`` node.
            input_name: Prefix of the new initializer name (``<input_name>_min``).

        Returns:
            A one-element list holding the element-wise (broadcast) minimum of all
            constants, or an empty list when neither node has a constant input.
        """
        inputs = (*first_node.inputs[1:], *second_node.inputs[1:])
        values = [input_.const_value.numpy() for input_ in inputs]
        if not values:
            # Min(Min(X)) has nothing to fold: the fused node is simply Min(X).
            # functools.reduce would raise TypeError on the empty list.
            return []
        return [(ir.tensor(functools.reduce(np.minimum, values)), f"{input_name}_min")]

    def pattern(self, op, x):
        """Match ``Min(Min(x, ...), ...)``, binding the two outputs as out1/out2."""
        return op.Min(
            op.Min(x, _allow_other_inputs=True, _outputs=["out1"]),
            _allow_other_inputs=True,
            _outputs=["out2"],
        )
141+
142+
143+
class FuseSuccessiveMax(_FuseMinMaxBase):
    """Replaces ``Max(Max(X, c1, c2, ...), d1, d2, ...)`` with ``Max(X, fused_const)``.

    Constraints:
        - All inputs except the first must be constants (from Constant nodes or initializers).
    """

    op_type: ClassVar = "Max"

    def compute_constants(
        self,
        first_node: ir.Node,
        second_node: ir.Node,
        input_name: str = "",
    ) -> list[tuple[ir.Tensor, str]]:
        """Fold the constants of both ``Max`` nodes into a single constant.

        Args:
            first_node: The inner ``Max`` node.
            second_node: The outer ``Max`` node.
            input_name: Prefix of the new initializer name (``<input_name>_max``).

        Returns:
            A one-element list holding the element-wise (broadcast) maximum of all
            constants, or an empty list when neither node has a constant input.
        """
        inputs = (*first_node.inputs[1:], *second_node.inputs[1:])
        values = [input_.const_value.numpy() for input_ in inputs]
        if not values:
            # Max(Max(X)) has nothing to fold: the fused node is simply Max(X).
            # functools.reduce would raise TypeError on the empty list.
            return []
        return [(ir.tensor(functools.reduce(np.maximum, values)), f"{input_name}_max")]

    def pattern(self, op, x):
        """Match ``Max(Max(x, ...), ...)``, binding the two outputs as out1/out2."""
        return op.Max(
            op.Max(x, _allow_other_inputs=True, _outputs=["out1"]),
            _allow_other_inputs=True,
            _outputs=["out2"],
        )
168+
169+
170+
class FuseMaxMinToClip(_FuseMinMaxBase):
    """Replaces ``Min(Max(X, lb1, lb2, ...), ub1, ub2, ...)`` with ``Clip(X, lb, ub)``.

    Constraints:
        - All inputs except the first must be constants (from Constant nodes or initializers).
        - All constant inputs must be scalars.
        - The effective lower bound is ``max(lb1, lb2, ...)``.
        - The effective upper bound is ``min(ub1, ub2, ...)``.
    """

    op_type: ClassVar = "Clip"
    need_scalars: ClassVar = True

    def compute_constants(
        self,
        first_node: ir.Node,
        second_node: ir.Node,
        input_name: str = "",
    ) -> list[tuple[ir.Tensor, str]]:
        """Compute the ``Clip`` bounds from the constants of both nodes.

        Args:
            first_node: The inner ``Max`` node; its constants are lower bounds.
            second_node: The outer ``Min`` node; its constants are upper bounds.
            input_name: Prefix of the new initializer names
                (``<input_name>_min`` and ``<input_name>_max``).

        Returns:
            ``[(lower_bound, name), (upper_bound, name)]`` in ``Clip`` input order.

        Raises:
            ValueError: If either node has no constant input (reduction over an
                empty sequence).
        """
        # Reduce every constant to a NumPy scalar first. Single-element constants
        # may have different ranks (shape () vs. (1,)), and stacking such arrays
        # directly into one list raises ValueError on recent NumPy versions.
        lower_values = [np.max(input_.const_value.numpy()) for input_ in first_node.inputs[1:]]
        upper_values = [np.min(input_.const_value.numpy()) for input_ in second_node.inputs[1:]]
        lower_bound = np.max(lower_values)
        upper_bound = np.min(upper_values)
        return [
            (ir.tensor(lower_bound), f"{input_name}_min"),
            (ir.tensor(upper_bound), f"{input_name}_max"),
        ]

    def pattern(self, op, x):
        """Match ``Min(Max(x, ...), ...)``, binding the two outputs as out1/out2."""
        return op.Min(
            op.Max(x, _allow_other_inputs=True, _outputs=["out1"]),
            _allow_other_inputs=True,
            _outputs=["out2"],
        )
202+
203+
204+
class FuseMinMaxToClip(_FuseMinMaxBase):
    """Replaces ``Max(Min(X, ub1, ub2, ...), lb1, lb2, ...)`` with ``Clip(X, lb, ub)``.

    Constraints:
        - All inputs except the first must be constants (from Constant nodes or initializers).
        - All constant inputs must be scalars.
        - The effective lower bound is ``max(lb1, lb2, ...)``.
        - The effective upper bound is ``min(ub1, ub2, ...)``.
        - Requires ``lower_bound <= upper_bound``.
    """

    op_type: ClassVar = "Clip"
    need_scalars: ClassVar = True
    check_bounds: ClassVar = True

    def compute_constants(
        self,
        first_node: ir.Node,
        second_node: ir.Node,
        input_name: str = "",
    ) -> list[tuple[ir.Tensor, str]]:
        """Compute the ``Clip`` bounds from the constants of both nodes.

        Args:
            first_node: The inner ``Min`` node; its constants are upper bounds.
            second_node: The outer ``Max`` node; its constants are lower bounds.
            input_name: Prefix of the new initializer names
                (``<input_name>_min`` and ``<input_name>_max``).

        Returns:
            ``[(lower_bound, name), (upper_bound, name)]`` in ``Clip`` input order.

        Raises:
            ValueError: If either node has no constant input (reduction over an
                empty sequence).
        """
        # Reduce every constant to a NumPy scalar first. Single-element constants
        # may have different ranks (shape () vs. (1,)), and stacking such arrays
        # directly into one list raises ValueError on recent NumPy versions.
        upper_values = [np.min(input_.const_value.numpy()) for input_ in first_node.inputs[1:]]
        lower_values = [np.max(input_.const_value.numpy()) for input_ in second_node.inputs[1:]]
        upper_bound = np.min(upper_values)
        lower_bound = np.max(lower_values)
        return [
            (ir.tensor(lower_bound), f"{input_name}_min"),
            (ir.tensor(upper_bound), f"{input_name}_max"),
        ]

    def pattern(self, op, x):
        """Match ``Max(Min(x, ...), ...)``, binding the two outputs as out1/out2."""
        return op.Max(
            op.Min(x, _allow_other_inputs=True, _outputs=["out1"]),
            _allow_other_inputs=True,
            _outputs=["out2"],
        )
238+
239+
240+
# One rule instance per supported pattern. The rule names list the ops in
# execution order (inner op first), e.g. ``min_max_rule`` handles Max(Min(X)).
min_min_rule = FuseSuccessiveMin().rule()  # Min(Min(X)) -> Min(X)
max_max_rule = FuseSuccessiveMax().rule()  # Max(Max(X)) -> Max(X)
min_max_rule = FuseMinMaxToClip().rule()  # Max(Min(X)) -> Clip(X)
max_min_rule = FuseMaxMinToClip().rule()  # Min(Max(X)) -> Clip(X)


# Rule set exposing all four fusions together.
rules = RewriteRuleSet([min_min_rule, max_max_rule, min_max_rule, max_min_rule])

0 commit comments

Comments
 (0)