Skip to content

Commit 717fa45

Browse files
committed
review(min_max_to_clip): check bounds only for Max(Min) pattern
1 parent 105b904 commit 717fa45

2 files changed

Lines changed: 12 additions & 10 deletions

File tree

onnxscript/rewriter/rules/common/_min_max_to_clip.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
* All constant inputs must be scalars.
1515
* The effective lower bound is the maximum of all lower-bound constants.
1616
* The effective upper bound is the minimum of all upper-bound constants.
17-
* The rule applies only if lower_bound ≤ upper_bound.
17+
18+
For the case of Max(Min(X, upper_bound), lower_bound):
19+
* The rule applies only if lower_bound ≤ upper_bound.
1820
1921
General constraints:
2022
- The first input may be any tensor.
@@ -38,9 +40,11 @@ class _FuseMinMaxBase(RewriteRuleClassBase, abc.ABC):
3840
Constraints:
3941
- All inputs except the first must be constants (from Constant nodes or initializers).
4042
- If ``need_scalars`` is True (Clip fusion), all constants must be scalars.
43+
- If ``check_bounds`` is True (Clip fusion in the pattern Max(Min(X, upper_bound), lower_bound)), lower_bound ≤ upper_bound.
4144
"""
4245

4346
need_scalars: ClassVar = False
47+
check_bounds: ClassVar = False
4448

4549
@abc.abstractmethod
4650
def compute_constants(
@@ -75,7 +79,8 @@ def check(self, context, out1, out2, **_):
7579
- These inputs (except the first) must be constant values (from Constant nodes or initializers).
7680
- In the case of Min(Max) and Max(Min) patterns:
7781
* All inputs must be scalars (as Clip requires scalars).
78-
* The lower bound must be less than or equal to the upper bound.
82+
For Max(Min) pattern:
83+
* The lower bound must be less than or equal to the upper bound.
7984
8085
Returns:
8186
MatchResult:
@@ -96,8 +101,8 @@ def check(self, context, out1, out2, **_):
96101
if self.need_scalars and not self._is_scalar(input_.const_value.numpy()):
97102
return check_result.fail(f"{input_.name} is not a scalar.")
98103

99-
if self.need_scalars:
100-
# For Clip fusion: check that lower_bound <= upper_bound
104+
if self.need_scalars and self.check_bounds:
105+
# For Clip fusion in the case of Max(Min(X, upper_bound), lower_bound): check that lower_bound <= upper_bound
101106
lower_bound, upper_bound = self.compute_constants(first_node, second_node)
102107
if lower_bound[0].numpy() > upper_bound[0].numpy():
103108
return check_result.fail(
@@ -170,7 +175,6 @@ class FuseMaxMinToClip(_FuseMinMaxBase):
170175
- All constant inputs must be scalars.
171176
- The effective lower bound is ``max(lb1, lb2, ...)``.
172177
- The effective upper bound is ``min(ub1, ub2, ...)``.
173-
- Requires ``lower_bound <= upper_bound``.
174178
"""
175179

176180
op_type: ClassVar = "Clip"
@@ -210,6 +214,7 @@ class FuseMinMaxToClip(_FuseMinMaxBase):
210214

211215
op_type: ClassVar = "Clip"
212216
need_scalars: ClassVar = True
217+
check_bounds: ClassVar = True
213218

214219
def compute_constants(
215220
self,

onnxscript/rewriter/rules/common/_min_max_to_clip_test.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,7 @@ def test_successful_max_min_to_clip_graph_inputs_as_constants(self):
311311
""")
312312
self.run_test(base_model, expected_op_types=["Clip"])
313313

314-
def test_failure_max_min_to_clip_invalid_bounds(self):
315-
"""Min node should have the max value and Max node should have the min value."""
314+
def test_successful_max_min_to_clip_check_bounds(self):
316315
base_model = ir.from_onnx_text("""
317316
< ir_version: 10, opset_import: ["" : 20] >
318317
test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y)
@@ -322,9 +321,7 @@ def test_failure_max_min_to_clip_invalid_bounds(self):
322321
Y = Min(x1, min)
323322
}
324323
""")
325-
self.run_failed_condition_test(
326-
base_model, fuse_successive_max_min_rule, "Invalid bounds:"
327-
)
324+
self.run_test(base_model, expected_op_types=["Clip"])
328325

329326
def test_failure_fuse_max_min_to_clip_non_constant(self):
330327
model = ir.from_onnx_text("""

0 commit comments

Comments
 (0)