1414 * All constant inputs must be scalars.
1515 * The effective lower bound is the maximum of all lower-bound constants.
1616 * The effective upper bound is the minimum of all upper-bound constants.
17- * The rule applies only if lower_bound ≤ upper_bound.
17+
18+ For the case of Max(Min(X, upper_bound), lower_bound):
19+ * The rule applies only if lower_bound ≤ upper_bound.
1820
1921General constraints:
2022 - The first input may be any tensor.
@@ -38,9 +40,11 @@ class _FuseMinMaxBase(RewriteRuleClassBase, abc.ABC):
3840 Constraints:
3941 - All inputs except the first must be constants (from Constant nodes or initializers).
4042 - If ``need_scalars`` is True (Clip fusion), all constants must be scalars.
43+ - If ``check_bounds`` is True (Clip fusion in the pattern Max(Min(X, upper_bound), lower_bound)), lower_bound ≤ upper_bound.
4144 """
4245
4346 need_scalars : ClassVar = False
47+ check_bounds : ClassVar = False
4448
4549 @abc .abstractmethod
4650 def compute_constants (
@@ -75,7 +79,8 @@ def check(self, context, out1, out2, **_):
7579 - These inputs (except the first) must be constant values (from Constant nodes or initializers).
7680 - In the case of Min(Max) and Max(Min) patterns:
7781 * All inputs must be scalars (as Clip requires scalars).
78- * The lower bound must be less than or equal to the upper bound.
82+ For Max(Min) pattern:
83+ * The lower bound must be less than or equal to the upper bound.
7984
8085 Returns:
8186 MatchResult:
@@ -96,8 +101,8 @@ def check(self, context, out1, out2, **_):
96101 if self .need_scalars and not self ._is_scalar (input_ .const_value .numpy ()):
97102 return check_result .fail (f"{ input_ .name } is not a scalar." )
98103
99- if self .need_scalars :
100- # For Clip fusion: check that lower_bound <= upper_bound
104+ if self .need_scalars and self . check_bounds :
105+ # For Clip fusion in the case of Max(Min(X, upper_bound), lower_bound) : check that lower_bound <= upper_bound
101106 lower_bound , upper_bound = self .compute_constants (first_node , second_node )
102107 if lower_bound [0 ].numpy () > upper_bound [0 ].numpy ():
103108 return check_result .fail (
@@ -170,7 +175,6 @@ class FuseMaxMinToClip(_FuseMinMaxBase):
170175 - All constant inputs must be scalars.
171176 - The effective lower bound is ``max(lb1, lb2, ...)``.
172177 - The effective upper bound is ``min(ub1, ub2, ...)``.
173- - Requires ``lower_bound <= upper_bound``.
174178 """
175179
176180 op_type : ClassVar = "Clip"
@@ -210,6 +214,7 @@ class FuseMinMaxToClip(_FuseMinMaxBase):
210214
211215 op_type : ClassVar = "Clip"
212216 need_scalars : ClassVar = True
217+ check_bounds : ClassVar = True
213218
214219 def compute_constants (
215220 self ,
0 commit comments