Skip to content

Commit 78e9033

Browse files
committed
[Rewriter]: add fusion rules for successive Min/Max patterns
- Min(Min(X)) -> Min(X)
- Max(Max(X)) -> Max(X)
- Min(Max(X)) -> Clip(X)
- Max(Min(X)) -> Clip(X)
1 parent 3af04e9 commit 78e9033

3 files changed

Lines changed: 642 additions & 0 deletions

File tree

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
collapse_slices,
3030
fuse_pad_into_conv,
3131
fuse_relus_clips,
32+
min_max_to_clip,
3233
no_op,
3334
pattern,
3435
redundant_scatter_nd,
@@ -47,6 +48,7 @@
4748
*broadcast_to_matmul.rules.rules,
4849
*cast_constant_of_shape.rules.rules,
4950
*collapse_slices.rules.rules,
51+
*min_max_to_clip.min_max_to_clip_rules().rules,
5052
*fuse_relus_clips.fuse_relus_clips_rules().rules,
5153
*basic_rules.basic_optimization_rules().rules,
5254
*redundant_scatter_nd.rules.rules,
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Fuses successive Min/Max patterns in ONNX graphs.
4+
5+
Supported transformations:
6+
- Min(Min(X, c1, c2, ...), d1, d2, ...) → Min(X, fused_const)
7+
- Max(Max(X, c1, c2, ...), d1, d2, ...) → Max(X, fused_const)
8+
- Min(Max(X, lb1, lb2, ...), ub1, ub2, ...) → Clip(X, lb, ub)
9+
- Max(Min(X, ub1, ub2, ...), lb1, lb2, ...) → Clip(X, lb, ub)
10+
11+
Where:
12+
- fused_const is the reduction (min or max) over all constant inputs.
13+
- For Clip fusion:
14+
* All constant inputs must be scalars.
15+
* The effective lower bound is the maximum of all lower-bound constants.
16+
* The effective upper bound is the minimum of all upper-bound constants.
17+
* The rule applies only if lower_bound ≤ upper_bound.
18+
19+
General constraints:
20+
- The first input may be any tensor.
21+
- All other inputs must be constant tensors (from Constant nodes or initializers).
22+
"""
23+
24+
import abc
25+
import functools
26+
from typing import ClassVar
27+
28+
import numpy as np
29+
import onnx_ir as ir
30+
31+
from onnxscript.rewriter._basics import MatchResult
32+
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
33+
34+
35+
class _FuseMinMaxBase(RewriteRuleClassBase, abc.ABC):
    """Base class for Min/Max fusion rewrites.

    A subclass provides ``pattern`` (a chain of two nodes whose outputs are
    bound to ``out1`` and ``out2``), ``op_type`` (the operator emitted by
    ``rewrite``) and ``compute_constants`` (how the constant inputs of both
    nodes are folded together).

    Constraints:
    - All inputs except the first must be constants (from Constant nodes or initializers).
    - If ``need_scalars`` is True (Clip fusion), all constants must be scalars.
    """

    # True for the Clip rules: every constant must then hold a single element
    # and ``compute_constants`` must return a (lower bound, upper bound) pair.
    need_scalars: ClassVar = False

    @abc.abstractmethod
    def compute_constants(
        self,
        first_node: ir.Node,
        second_node: ir.Node,
        input_name: str = "",
    ) -> list[tuple[ir.Tensor, str]]:
        """Fold the constant inputs of both nodes into the fused op's constants.

        Args:
            first_node: The inner node of the matched chain (producer of ``out1``).
            second_node: The outer node of the matched chain (producer of ``out2``).
            input_name: Name of the non-constant input, used as a prefix for
                the names of the returned constants.

        Returns:
            ``(tensor, name)`` pairs, in the order they are passed to the fused op.
        """
        ...

    def rewrite(self, op, x, out1, out2):
        """Build the fused node that replaces the matched two-node chain."""
        first_node = out1.producer()
        second_node = out2.producer()

        # Min(Min(X)) / Max(Max(X)) may carry no constant at all. There is then
        # nothing to fold, and the fused op takes X as its only input. The Clip
        # rules never reach this branch: ``check`` rejects a missing bound.
        if len(first_node.inputs) > 1 or len(second_node.inputs) > 1:
            # Compute new constants for the fused op
            constants = self.compute_constants(first_node, second_node, x.name)
        else:
            constants = []

        initializers = [op.initializer(constant, name=name) for constant, name in constants]

        return op.op(
            self.op_type,
            inputs=[x, *initializers],
        )

    def _is_scalar(self, v: np.ndarray) -> bool:
        """Return True if ``v`` is a Python/NumPy scalar or holds exactly one element."""
        return np.isscalar(v) or np.size(v) == 1

    def check(self, context, out1, out2, **_):
        """Condition to check if we need to replace the pattern.

        Conditions:
        - The min and max input nodes must not be graph inputs.
        - These inputs (except the first) must be constant values (from Constant nodes or initializers).
        - In the case of Min(Max) and Max(Min) patterns:
            * All inputs must be scalars (as Clip requires scalars).
            * Both nodes must provide at least one constant, so that both
              bounds of the Clip are defined.
            * The lower bound must be less than or equal to the upper bound.

        Returns:
            MatchResult:
                Success if we need to replace the pattern, Failure otherwise.
        """
        del context  # Not used
        check_result = MatchResult()

        first_node = out1.producer()
        second_node = out2.producer()
        first_constants = first_node.inputs[1:]
        second_constants = second_node.inputs[1:]

        # Ensure all inputs except the first are constants
        for input_ in (*first_constants, *second_constants):
            if input_.is_graph_input():
                return check_result.fail(f"{input_.name} is a graph input.")

            # Keep the tensor returned by the lookup: for the output of a
            # Constant node, ``input_.const_value`` itself may be unset.
            tensor = ir.convenience.get_const_tensor(input_)
            if tensor is None:
                return check_result.fail(f"{input_.name} is not a constant.")

            # If scalars are required (Clip fusion), enforce scalar-ness
            if self.need_scalars and not self._is_scalar(tensor.numpy()):
                return check_result.fail(f"{input_.name} is not a scalar.")

        if self.need_scalars:
            # Reducing an empty list of bounds would raise; reject the match instead.
            if not first_constants or not second_constants:
                return check_result.fail(
                    "Clip fusion requires at least one constant on both nodes."
                )

            # For Clip fusion: check that lower_bound <= upper_bound
            lower_bound, upper_bound = self.compute_constants(first_node, second_node)
            if lower_bound[0].numpy() > upper_bound[0].numpy():
                return check_result.fail(
                    f"Invalid bounds: lower bound ({lower_bound[0].numpy()}) is greater "
                    f"than upper bound ({upper_bound[0].numpy()})."
                )

        return check_result
112+
113+
114+
class FuseSuccessiveMin(_FuseMinMaxBase):
    """Replaces ``Min(Min(X, c1, c2, ...), d1, d2, ...)`` with ``Min(X, fused_const)``.

    ``fused_const`` is the element-wise minimum (with NumPy broadcasting) of
    all constant inputs of both nodes.

    Constraints:
    - All inputs except the first must be constants (from Constant nodes or initializers).
    """

    op_type: ClassVar = "Min"

    def compute_constants(
        self,
        first_node: ir.Node,
        second_node: ir.Node,
        input_name: str = "",
    ) -> list[tuple[ir.Tensor, str]]:
        """Fold every constant input of both nodes into a single minimum.

        Args:
            first_node: The inner ``Min`` node.
            second_node: The outer ``Min`` node.
            input_name: Name of the non-constant input; the fused constant is
                named ``f"{input_name}_min"``.

        Returns:
            A one-element list holding the fused constant and its name, or an
            empty list when neither node has a constant input.
        """
        inputs = first_node.inputs[1:] + second_node.inputs[1:]
        if not inputs:
            # Min(Min(X)): nothing to fold; reducing an empty sequence would raise.
            return []
        # Read through get_const_tensor so that outputs of Constant nodes are
        # handled even when their ``const_value`` is not populated.
        values = [ir.convenience.get_const_tensor(input_).numpy() for input_ in inputs]
        return [(ir.tensor(functools.reduce(np.minimum, values)), f"{input_name}_min")]

    def pattern(self, op, x):
        """Match ``Min(Min(x, ...), ...)``, binding the inner and outer outputs."""
        return op.Min(
            op.Min(x, _allow_other_inputs=True, _outputs=["out1"]),
            _allow_other_inputs=True,
            _outputs=["out2"],
        )
139+
140+
141+
class FuseSuccessiveMax(_FuseMinMaxBase):
    """Replaces ``Max(Max(X, c1, c2, ...), d1, d2, ...)`` with ``Max(X, fused_const)``.

    ``fused_const`` is the element-wise maximum (with NumPy broadcasting) of
    all constant inputs of both nodes.

    Constraints:
    - All inputs except the first must be constants (from Constant nodes or initializers).
    """

    op_type: ClassVar = "Max"

    def compute_constants(
        self,
        first_node: ir.Node,
        second_node: ir.Node,
        input_name: str = "",
    ) -> list[tuple[ir.Tensor, str]]:
        """Fold every constant input of both nodes into a single maximum.

        Args:
            first_node: The inner ``Max`` node.
            second_node: The outer ``Max`` node.
            input_name: Name of the non-constant input; the fused constant is
                named ``f"{input_name}_max"``.

        Returns:
            A one-element list holding the fused constant and its name, or an
            empty list when neither node has a constant input.
        """
        inputs = first_node.inputs[1:] + second_node.inputs[1:]
        if not inputs:
            # Max(Max(X)): nothing to fold; reducing an empty sequence would raise.
            return []
        # Read through get_const_tensor so that outputs of Constant nodes are
        # handled even when their ``const_value`` is not populated.
        values = [ir.convenience.get_const_tensor(input_).numpy() for input_ in inputs]
        return [(ir.tensor(functools.reduce(np.maximum, values)), f"{input_name}_max")]

    def pattern(self, op, x):
        """Match ``Max(Max(x, ...), ...)``, binding the inner and outer outputs."""
        return op.Max(
            op.Max(x, _allow_other_inputs=True, _outputs=["out1"]),
            _allow_other_inputs=True,
            _outputs=["out2"],
        )
166+
167+
168+
class FuseMaxMinToClip(_FuseMinMaxBase):
    """Replaces ``Min(Max(X, lb1, lb2, ...), ub1, ub2, ...)`` with ``Clip(X, lb, ub)``.

    Constraints:
    - All inputs except the first must be constants (from Constant nodes or initializers).
    - All constant inputs must be scalars.
    - The effective lower bound is ``max(lb1, lb2, ...)``.
    - The effective upper bound is ``min(ub1, ub2, ...)``.
    - Requires ``lower_bound <= upper_bound``.
    """

    op_type: ClassVar = "Clip"
    need_scalars: ClassVar = True

    def compute_constants(
        self,
        first_node: ir.Node,
        second_node: ir.Node,
        input_name: str = "",
    ) -> list[tuple[ir.Tensor, str]]:
        """Compute the Clip bounds from the constants of the Max and Min nodes.

        Args:
            first_node: The inner ``Max`` node, whose constants are lower bounds.
            second_node: The outer ``Min`` node, whose constants are upper bounds.
            input_name: Name of the non-constant input, used as a prefix for
                the names of the two bounds.

        Returns:
            ``[(lower_bound, f"{input_name}_min"), (upper_bound, f"{input_name}_max")]``.
        """
        # Each constant holds a single element but may have any rank (shape ()
        # or (1,), for instance). Collapsing each one to rank 0 first keeps the
        # reduction from failing on a list of differently shaped arrays.
        lower_values = [
            ir.convenience.get_const_tensor(input_).numpy().reshape(())
            for input_ in first_node.inputs[1:]
        ]
        upper_values = [
            ir.convenience.get_const_tensor(input_).numpy().reshape(())
            for input_ in second_node.inputs[1:]
        ]
        lower_bound = np.max(lower_values)
        upper_bound = np.min(upper_values)
        return [
            (ir.tensor(lower_bound), f"{input_name}_min"),
            (ir.tensor(upper_bound), f"{input_name}_max"),
        ]

    def pattern(self, op, x):
        """Match ``Min(Max(x, ...), ...)``, binding the inner and outer outputs."""
        return op.Min(
            op.Max(x, _allow_other_inputs=True, _outputs=["out1"]),
            _allow_other_inputs=True,
            _outputs=["out2"],
        )
201+
202+
203+
class FuseMinMaxToClip(_FuseMinMaxBase):
    """Replaces ``Max(Min(X, ub1, ub2, ...), lb1, lb2, ...)`` with ``Clip(X, lb, ub)``.

    Constraints:
    - All inputs except the first must be constants (from Constant nodes or initializers).
    - All constant inputs must be scalars.
    - The effective lower bound is ``max(lb1, lb2, ...)``.
    - The effective upper bound is ``min(ub1, ub2, ...)``.
    - Requires ``lower_bound <= upper_bound``.
    """

    op_type: ClassVar = "Clip"
    need_scalars: ClassVar = True

    def compute_constants(
        self,
        first_node: ir.Node,
        second_node: ir.Node,
        input_name: str = "",
    ) -> list[tuple[ir.Tensor, str]]:
        """Compute the Clip bounds from the constants of the Min and Max nodes.

        Args:
            first_node: The inner ``Min`` node, whose constants are upper bounds.
            second_node: The outer ``Max`` node, whose constants are lower bounds.
            input_name: Name of the non-constant input, used as a prefix for
                the names of the two bounds.

        Returns:
            ``[(lower_bound, f"{input_name}_min"), (upper_bound, f"{input_name}_max")]``.
        """
        # Each constant holds a single element but may have any rank (shape ()
        # or (1,), for instance). Collapsing each one to rank 0 first keeps the
        # reduction from failing on a list of differently shaped arrays.
        upper_values = [
            ir.convenience.get_const_tensor(input_).numpy().reshape(())
            for input_ in first_node.inputs[1:]
        ]
        lower_values = [
            ir.convenience.get_const_tensor(input_).numpy().reshape(())
            for input_ in second_node.inputs[1:]
        ]
        upper_bound = np.min(upper_values)
        lower_bound = np.max(lower_values)
        return [
            (ir.tensor(lower_bound), f"{input_name}_min"),
            (ir.tensor(upper_bound), f"{input_name}_max"),
        ]

    def pattern(self, op, x):
        """Match ``Max(Min(x, ...), ...)``, binding the inner and outer outputs."""
        return op.Max(
            op.Min(x, _allow_other_inputs=True, _outputs=["out1"]),
            _allow_other_inputs=True,
            _outputs=["out2"],
        )
236+
237+
238+
# Module-level rule instances, one per supported pattern.
# Same-operator chains: Min(Min(X)) and Max(Max(X)).
fuse_successive_min_rule, fuse_successive_max_rule = (
    FuseSuccessiveMin().rule(),
    FuseSuccessiveMax().rule(),
)
# Mixed chains fused into Clip: "min_max" is Max(Min(X)), "max_min" is Min(Max(X)).
fuse_successive_min_max_rule, fuse_successive_max_min_rule = (
    FuseMinMaxToClip().rule(),
    FuseMaxMinToClip().rule(),
)
242+
243+
244+
def min_max_to_clip_rules() -> RewriteRuleSet:
    """Bundle the four successive Min/Max fusion rules into one rule set.

    The set contains, in order, the Min(Min), Max(Max), Max(Min) -> Clip and
    Min(Max) -> Clip rules defined at module level.

    Returns:
        RewriteRuleSet
    """
    fusion_rules = [
        fuse_successive_min_rule,
        fuse_successive_max_rule,
        fuse_successive_min_max_rule,
        fuse_successive_max_min_rule,
    ]
    return RewriteRuleSet(fusion_rules)

0 commit comments

Comments
 (0)