Skip to content

Commit c21bb05

Browse files
committed
[Rewriter]: add fusion rules for successive Min/Max patterns
- Min(Min(X)) -> Min(X) - Max(Max(X)) -> Max(X) - Min(Max(X)) -> Clip(X) - Max(Min(X)) -> Clip(X)
1 parent fe152d4 commit c21bb05

2 files changed

Lines changed: 595 additions & 0 deletions

File tree

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Does the following transformation:
4+
- Min(Min(X)) -> Min(X)
5+
- Max(Max(X)) -> Max(X)
6+
- Min(Max(X)) -> Clip(X)
7+
- Max(Min(X)) -> Clip(X)
8+
"""
9+
10+
import abc
11+
import functools
12+
from typing import ClassVar
13+
14+
import numpy as np
15+
import onnx_ir as ir
16+
17+
from onnxscript.rewriter._basics import MatchResult
18+
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
19+
20+
21+
class _FuseMinMaxBase(RewriteRuleClassBase, abc.ABC):
22+
need_scalars: ClassVar = False
23+
24+
@abc.abstractmethod
25+
def compute_constants(
26+
self,
27+
first_node: ir.Node,
28+
second_node: ir.Node,
29+
input_name: str = "",
30+
) -> list[tuple[ir.Tensor, str]]: ...
31+
32+
def rewrite(self, op, x, out1, out2):
33+
first_node = out1.producer()
34+
second_node = out2.producer()
35+
36+
# Compute new constants for the fused op
37+
constants = self.compute_constants(first_node, second_node, x.name)
38+
39+
initializers = [op.initializer(constant, name=name) for constant, name in constants]
40+
41+
return op.op(
42+
self.op_type,
43+
inputs=[x, *initializers],
44+
)
45+
46+
def _is_scalar(self, v: np.ndarray) -> bool:
47+
return np.isscalar(v) or v.shape == () or (v.shape == (1,))
48+
49+
def check(self, context, out1, out2, **_):
50+
"""Condition to check if we need to replace the pattern.
51+
52+
Conditions:
53+
- The min and max input nodes must not be graph inputs.
54+
- These inputs (except the first) must be constant values (from Constant nodes or initializers).
55+
- In the case of Min(Max) and Max(Min) patterns:
56+
* All inputs must be scalars (as Clip requires scalars).
57+
* The lower bound must be less than or equal to the upper bound.
58+
59+
Returns:
60+
MatchResult:
61+
Success if we need to replace the pattern, Failure otherwise.
62+
"""
63+
del context # Not used
64+
check_result = MatchResult()
65+
66+
first_node = out1.producer()
67+
second_node = out2.producer()
68+
69+
# Ensure all inputs except the first are constants
70+
for input_ in first_node.inputs[1:] + second_node.inputs[1:]:
71+
if input_.is_graph_input():
72+
return check_result.fail(f"{input_.name} is a graph input.")
73+
74+
if ir.convenience.get_const_tensor(input_) is None:
75+
return check_result.fail(f"{input_.name} is not a constant.")
76+
77+
# If scalars are required (Clip fusion), enforce scalar-ness
78+
if self.need_scalars and not self._is_scalar(input_.const_value.numpy()):
79+
return check_result.fail(f"{input_.name} is not a scalar.")
80+
81+
if self.need_scalars:
82+
# For Clip fusion: check that lower_bound <= upper_bound
83+
lower_bound, upper_bound = self.compute_constants(first_node, second_node)
84+
if lower_bound[0].numpy() > upper_bound[0].numpy():
85+
return check_result.fail(
86+
f"Invalid bounds: lower bound ({lower_bound[0].numpy()}) is greater "
87+
f"than upper bound ({upper_bound[0].numpy()})."
88+
)
89+
90+
return check_result
91+
92+
93+
class FuseSuccessiveMin(_FuseMinMaxBase):
94+
"""Replaces ``Min(Min(X))`` with ``Min(X)``."""
95+
96+
op_type: ClassVar = "Min"
97+
98+
def compute_constants(
99+
self,
100+
first_node: ir.Node,
101+
second_node: ir.Node,
102+
input_name: str = "",
103+
) -> list[tuple[ir.Tensor, str]]:
104+
inputs = first_node.inputs[1:] + second_node.inputs[1:]
105+
values = [input_.const_value.numpy() for input_ in inputs]
106+
return [(ir.tensor(functools.reduce(np.minimum, values)), f"{input_name}_min")]
107+
108+
def pattern(self, op, x):
109+
return op.Min(
110+
op.Min(x, _allow_other_inputs=True, _outputs=["out1"]),
111+
_allow_other_inputs=True,
112+
_outputs=["out2"],
113+
)
114+
115+
116+
class FuseSuccessiveMax(_FuseMinMaxBase):
117+
"""Replaces ``Max(Max(X))`` with ``Max(X)``."""
118+
119+
op_type: ClassVar = "Max"
120+
121+
def compute_constants(
122+
self,
123+
first_node: ir.Node,
124+
second_node: ir.Node,
125+
input_name: str = "",
126+
) -> list[tuple[ir.Tensor, str]]:
127+
inputs = first_node.inputs[1:] + second_node.inputs[1:]
128+
values = [input_.const_value.numpy() for input_ in inputs]
129+
return [(ir.tensor(functools.reduce(np.maximum, values)), f"{input_name}_max")]
130+
131+
def pattern(self, op, x):
132+
return op.Max(
133+
op.Max(x, _allow_other_inputs=True, _outputs=["out1"]),
134+
_allow_other_inputs=True,
135+
_outputs=["out2"],
136+
)
137+
138+
139+
class FuseMaxMinToClip(_FuseMinMaxBase):
140+
"""Replaces ``Min(Max(X))`` with ``Clip(X)``."""
141+
142+
op_type: ClassVar = "Clip"
143+
need_scalars: ClassVar = True
144+
145+
def compute_constants(
146+
self,
147+
first_node: ir.Node,
148+
second_node: ir.Node,
149+
input_name: str = "",
150+
) -> list[tuple[ir.Tensor, str]]:
151+
lower_bound = np.max([input_.const_value.numpy() for input_ in first_node.inputs[1:]])
152+
upper_bound = np.min([input_.const_value.numpy() for input_ in second_node.inputs[1:]])
153+
return [
154+
(ir.tensor(lower_bound), f"{input_name}_min"),
155+
(ir.tensor(upper_bound), f"{input_name}_max"),
156+
]
157+
158+
def pattern(self, op, x):
159+
return op.Min(
160+
op.Max(x, _allow_other_inputs=True, _outputs=["out1"]),
161+
_allow_other_inputs=True,
162+
_outputs=["out2"],
163+
)
164+
165+
166+
class FuseMinMaxToClip(_FuseMinMaxBase):
167+
"""Replaces ``Max(Min(X))`` with ``Clip(X)``."""
168+
169+
op_type: ClassVar = "Clip"
170+
need_scalars: ClassVar = True
171+
172+
def compute_constants(
173+
self,
174+
first_node: ir.Node,
175+
second_node: ir.Node,
176+
input_name: str = "",
177+
) -> list[tuple[ir.Tensor, str]]:
178+
upper_bound = np.min([input_.const_value.numpy() for input_ in first_node.inputs[1:]])
179+
lower_bound = np.max([input_.const_value.numpy() for input_ in second_node.inputs[1:]])
180+
return [
181+
(ir.tensor(lower_bound), f"{input_name}_min"),
182+
(ir.tensor(upper_bound), f"{input_name}_max"),
183+
]
184+
185+
def pattern(self, op, x):
186+
return op.Max(
187+
op.Min(x, _allow_other_inputs=True, _outputs=["out1"]),
188+
_allow_other_inputs=True,
189+
_outputs=["out2"],
190+
)
191+
192+
193+
fuse_successive_min_rule = FuseSuccessiveMin().rule()
194+
fuse_successive_max_rule = FuseSuccessiveMax().rule()
195+
fuse_successive_min_max_rule = FuseMinMaxToClip().rule()
196+
fuse_successive_max_min_rule = FuseMaxMinToClip().rule()
197+
198+
199+
def min_max_to_clip_rules() -> RewriteRuleSet:
200+
"""Returns a set of rewrite rules that fuse successive Min/Max nodes.
201+
202+
Returns:
203+
RewriteRuleSet
204+
"""
205+
206+
return RewriteRuleSet(
207+
[
208+
fuse_successive_min_rule,
209+
fuse_successive_max_rule,
210+
fuse_successive_min_max_rule,
211+
fuse_successive_max_min_rule,
212+
]
213+
)

0 commit comments

Comments
 (0)