diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py
index 1d07e9f5af..232750af78 100644
--- a/onnxscript/rewriter/__init__.py
+++ b/onnxscript/rewriter/__init__.py
@@ -37,6 +37,7 @@
     _collapse_slices,
     _fuse_pad_into_conv,
     _fuse_relus_clips,
+    _min_max_to_clip,
     _no_op,
     _redundant_scatter_nd,
 )
@@ -47,6 +48,7 @@
     *_broadcast_to_matmul.rules,
     *_cast_constant_of_shape.rules,
     *_collapse_slices.rules,
+    *_min_max_to_clip.rules,
     *_fuse_relus_clips.rules,
     *_basic_rules.basic_optimization_rules(),
     *_redundant_scatter_nd.rules,
diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py
index 752e3c9430..e86b46cd7b 100644
--- a/onnxscript/rewriter/rules/common/__init__.py
+++ b/onnxscript/rewriter/rules/common/__init__.py
@@ -15,6 +15,10 @@
     "fuse_batchnorm_into_gemm_rule",
     "fuse_pad_into_conv_integer_rule",
     "fuse_pad_into_conv_rule",
+    "min_min_rule",
+    "max_max_rule",
+    "min_max_rule",
+    "max_min_rule",
     "gemm_to_matmul_add_rule",
     "matmul_add_to_gemm_rule",
     "mul_by_1_rule",
@@ -89,6 +93,12 @@
     transpose_ab_matmul_add_to_gemm_rule,
     transpose_b_matmul_add_to_gemm_rule,
 )
+from onnxscript.rewriter.rules.common._min_max_to_clip import (
+    max_max_rule,
+    max_min_rule,
+    min_max_rule,
+    min_min_rule,
+)
 from onnxscript.rewriter.rules.common._no_op import (
     add_0_rule,
     div_by_1_rule,
diff --git a/onnxscript/rewriter/rules/common/_min_max_to_clip.py b/onnxscript/rewriter/rules/common/_min_max_to_clip.py
new file mode 100644
index 0000000000..88ae495dbc
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_min_max_to_clip.py
@@ -0,0 +1,295 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Fuses successive Min/Max patterns in ONNX graphs.
+
+Supported transformations:
+- Min(Min(X, c1, c2, ...), d1, d2, ...) → Min(X, fused_const)
+- Max(Max(X, c1, c2, ...), d1, d2, ...) → Max(X, fused_const)
+- Min(Max(X, lb1, lb2, ...), ub1, ub2, ...) → Clip(X, lb, ub)
+- Max(Min(X, ub1, ub2, ...), lb1, lb2, ...) → Clip(X, lb, ub)
+
+Where:
+    - fused_const is the reduction (min or max) over all constant inputs.
+    - For Clip fusion:
+        * All constant inputs must be scalars.
+        * The effective lower bound is the maximum of all lower-bound constants.
+        * The effective upper bound is the minimum of all upper-bound constants.
+
+        For the case of Max(Min(X, upper_bound), lower_bound):
+        * The rule applies only if lower_bound ≤ upper_bound.
+
+General constraints:
+    - The first input may be any tensor.
+    - All other inputs must be constant tensors (from Constant nodes or initializers).
+"""
+
+import abc
+import functools
+from typing import ClassVar
+
+import numpy as np
+import onnx_ir as ir
+
+from onnxscript.rewriter._basics import MatchResult
+from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
+
+
+class _FuseMinMaxBase(RewriteRuleClassBase, abc.ABC):
+    """Base class for Min/Max fusion rewrites.
+
+    Subclasses set ``op_type`` (the op emitted by ``rewrite``) and implement
+    ``pattern`` and ``compute_constants``.
+
+    Constraints:
+        - All inputs except the first must be constants (from Constant nodes or initializers).
+        - If ``need_scalars`` is True (Clip fusion), all constants must be scalars.
+        - If ``check_bounds`` is True (Clip fusion in the pattern
+          Max(Min(X, upper_bound), lower_bound)), lower_bound ≤ upper_bound.
+    """
+
+    # Op type of the fused node emitted by ``rewrite``; set by each subclass.
+    op_type: ClassVar[str]
+    # True when every constant must hold a single element (Clip fusion).
+    need_scalars: ClassVar[bool] = False
+    # True when the fusion is only valid for lower_bound <= upper_bound.
+    check_bounds: ClassVar[bool] = False
+
+    @abc.abstractmethod
+    def compute_constants(
+        self,
+        first_node: ir.Node,
+        second_node: ir.Node,
+        input_name: str = "",
+    ) -> list[tuple[ir.Tensor, str]]:
+        """Return the ``(tensor, initializer_name)`` pairs feeding the fused op.
+
+        Args:
+            first_node: Producer of the inner Min/Max output (``out1``).
+            second_node: Producer of the outer Min/Max output (``out2``).
+            input_name: Prefix used to build the initializer names.
+        """
+
+    @staticmethod
+    def _fold_scalars(inputs, ufunc) -> np.ndarray:
+        """Reduce single-element constants with ``ufunc`` into a 0-d array.
+
+        Each value is reshaped to rank 0 first, so constants of shape ``()`` and
+        ``(1,)`` can be mixed. Passing such a ragged list to ``np.max``/``np.min``
+        fails because the values cannot be stacked into one array. The dtype of
+        the constants is preserved.
+        """
+        scalars = (input_.const_value.numpy().reshape(()) for input_ in inputs)
+        return np.asarray(functools.reduce(ufunc, scalars))
+
+    def rewrite(self, op, x, out1, out2):
+        """Emit ``op_type(x, *fused_constants)`` in place of the matched pair."""
+        first_node = out1.producer()
+        second_node = out2.producer()
+
+        # Compute new constants for the fused op
+        constants = self.compute_constants(first_node, second_node, x.name)
+
+        initializers = [op.initializer(constant, name=name) for constant, name in constants]
+
+        return op.op(
+            self.op_type,
+            inputs=[x, *initializers],
+        )
+
+    def _is_scalar(self, v: np.ndarray) -> bool:
+        """Return True if ``v`` holds exactly one element, whatever its rank."""
+        return np.isscalar(v) or np.size(v) == 1
+
+    def check(self, context, out1, out2, **_):
+        """Condition to check if we need to replace the pattern.
+
+        Conditions:
+            - There must be at least one constant to fuse; for Clip fusion both
+              the Min and the Max node must provide one.
+            - All inputs except the first must be constant values (from Constant
+              nodes or initializers).
+            - In the case of Min(Max) and Max(Min) patterns:
+                * All constants must be scalars (as Clip requires scalars).
+            - For Max(Min) pattern:
+                * The lower bound must be less than or equal to the upper bound.
+
+        Returns:
+            MatchResult:
+                Success if we need to replace the pattern, Failure otherwise.
+        """
+        del context  # Not used
+        check_result = MatchResult()
+
+        first_node = out1.producer()
+        second_node = out2.producer()
+        first_constants = first_node.inputs[1:]
+        second_constants = second_node.inputs[1:]
+
+        # A Min/Max node may carry only its first input. Reducing an empty list
+        # of constants raises, so reject such matches instead of crashing.
+        if self.need_scalars:
+            # Clip fusion needs a lower and an upper bound.
+            if not first_constants or not second_constants:
+                return check_result.fail("Min and Max must both have a constant input.")
+        elif not first_constants and not second_constants:
+            return check_result.fail("There is no constant input to fuse.")
+
+        # Ensure all inputs except the first are constants
+        for input_ in first_constants + second_constants:
+            if ir.convenience.get_const_tensor(input_) is None:
+                return check_result.fail(f"{input_.name} is not a constant.")
+
+            # If scalars are required (Clip fusion), enforce scalar-ness
+            if self.need_scalars and not self._is_scalar(input_.const_value.numpy()):
+                return check_result.fail(f"{input_.name} is not a scalar.")
+
+        if self.need_scalars and self.check_bounds:
+            # For Clip fusion in the case of Max(Min(X, upper_bound), lower_bound): check that lower_bound <= upper_bound
+            lower_bound, upper_bound = self.compute_constants(first_node, second_node)
+            if lower_bound[0].numpy() > upper_bound[0].numpy():
+                return check_result.fail(
+                    f"Invalid bounds: lower bound ({lower_bound[0].numpy()}) is greater "
+                    f"than upper bound ({upper_bound[0].numpy()})."
+                )
+
+        return check_result
+
+
+class FuseSuccessiveMin(_FuseMinMaxBase):
+    """Replaces ``Min(Min(X, c1, c2, ...), d1, d2, ...)`` with ``Min(X, fused_const)``.
+
+    Constraints:
+        - All inputs except the first must be constants (from Constant nodes or initializers).
+    """
+
+    op_type: ClassVar = "Min"
+
+    def compute_constants(
+        self,
+        first_node: ir.Node,
+        second_node: ir.Node,
+        input_name: str = "",
+    ) -> list[tuple[ir.Tensor, str]]:
+        inputs = first_node.inputs[1:] + second_node.inputs[1:]
+        values = [input_.const_value.numpy() for input_ in inputs]
+        return [(ir.tensor(functools.reduce(np.minimum, values)), f"{input_name}_min")]
+
+    def pattern(self, op, x):
+        return op.Min(
+            op.Min(x, _allow_other_inputs=True, _outputs=["out1"]),
+            _allow_other_inputs=True,
+            _outputs=["out2"],
+        )
+
+
+class FuseSuccessiveMax(_FuseMinMaxBase):
+    """Replaces ``Max(Max(X, c1, c2, ...), d1, d2, ...)`` with ``Max(X, fused_const)``.
+
+    Constraints:
+        - All inputs except the first must be constants (from Constant nodes or initializers).
+    """
+
+    op_type: ClassVar = "Max"
+
+    def compute_constants(
+        self,
+        first_node: ir.Node,
+        second_node: ir.Node,
+        input_name: str = "",
+    ) -> list[tuple[ir.Tensor, str]]:
+        inputs = first_node.inputs[1:] + second_node.inputs[1:]
+        values = [input_.const_value.numpy() for input_ in inputs]
+        return [(ir.tensor(functools.reduce(np.maximum, values)), f"{input_name}_max")]
+
+    def pattern(self, op, x):
+        return op.Max(
+            op.Max(x, _allow_other_inputs=True, _outputs=["out1"]),
+            _allow_other_inputs=True,
+            _outputs=["out2"],
+        )
+
+
+class FuseMaxMinToClip(_FuseMinMaxBase):
+    """Replaces ``Min(Max(X, lb1, lb2, ...), ub1, ub2, ...)`` with ``Clip(X, lb, ub)``.
+
+    Constraints:
+        - All inputs except the first must be constants (from Constant nodes or initializers).
+        - All constant inputs must be scalars.
+        - The effective lower bound is ``max(lb1, lb2, ...)``.
+        - The effective upper bound is ``min(ub1, ub2, ...)``.
+    """
+
+    op_type: ClassVar = "Clip"
+    need_scalars: ClassVar = True
+
+    def compute_constants(
+        self,
+        first_node: ir.Node,
+        second_node: ir.Node,
+        input_name: str = "",
+    ) -> list[tuple[ir.Tensor, str]]:
+        lower_bound = self._fold_scalars(first_node.inputs[1:], np.maximum)
+        upper_bound = self._fold_scalars(second_node.inputs[1:], np.minimum)
+        return [
+            (ir.tensor(lower_bound), f"{input_name}_min"),
+            (ir.tensor(upper_bound), f"{input_name}_max"),
+        ]
+
+    def pattern(self, op, x):
+        return op.Min(
+            op.Max(x, _allow_other_inputs=True, _outputs=["out1"]),
+            _allow_other_inputs=True,
+            _outputs=["out2"],
+        )
+
+
+class FuseMinMaxToClip(_FuseMinMaxBase):
+    """Replaces ``Max(Min(X, ub1, ub2, ...), lb1, lb2, ...)`` with ``Clip(X, lb, ub)``.
+
+    Constraints:
+        - All inputs except the first must be constants (from Constant nodes or initializers).
+        - All constant inputs must be scalars.
+        - The effective lower bound is ``max(lb1, lb2, ...)``.
+        - The effective upper bound is ``min(ub1, ub2, ...)``.
+        - Requires ``lower_bound <= upper_bound``.
+    """
+
+    op_type: ClassVar = "Clip"
+    need_scalars: ClassVar = True
+    check_bounds: ClassVar = True
+
+    def compute_constants(
+        self,
+        first_node: ir.Node,
+        second_node: ir.Node,
+        input_name: str = "",
+    ) -> list[tuple[ir.Tensor, str]]:
+        upper_bound = self._fold_scalars(first_node.inputs[1:], np.minimum)
+        lower_bound = self._fold_scalars(second_node.inputs[1:], np.maximum)
+        return [
+            (ir.tensor(lower_bound), f"{input_name}_min"),
+            (ir.tensor(upper_bound), f"{input_name}_max"),
+        ]
+
+    def pattern(self, op, x):
+        return op.Max(
+            op.Min(x, _allow_other_inputs=True, _outputs=["out1"]),
+            _allow_other_inputs=True,
+            _outputs=["out2"],
+        )
+
+
+min_min_rule = FuseSuccessiveMin().rule()
+max_max_rule = FuseSuccessiveMax().rule()
+min_max_rule = FuseMinMaxToClip().rule()
+max_min_rule = FuseMaxMinToClip().rule()
+
+
+rules = RewriteRuleSet(
+    [
+        min_min_rule,
+        max_max_rule,
+        min_max_rule,
+        max_min_rule,
+    ]
+)
diff --git a/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py b/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py
new file mode 100644
index 0000000000..dd09078a9e
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py
@@ -0,0 +1,367 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import unittest
+
+import numpy as np
+import onnx
+import onnx_ir as ir
+from onnx_ir.passes.common import onnx_checker, shape_inference
+from parameterized import parameterized
+
+from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing
+from onnxscript.rewriter.rules.common._min_max_to_clip import (
+    max_max_rule,
+    max_min_rule,
+    min_max_rule,
+    min_min_rule,
+    rules,
+)
+
+# NOTE(review): several model texts below use names such as ``cst1``, ``min`` and
+# ``max`` and a bare ``Constant()`` without any declaration in the text as it
+# appears here (an empty line sits where a declaration would be expected). The
+# ``<type[shape] name = {...}>`` initializer lists and the ``Constant<...>``
+# attributes look like they were lost; restore them before relying on these tests.
+
+
+class _TestMinMaxToClipBase(unittest.TestCase):
+    """Shared helpers for the Min/Max fusion rule tests."""
+
+    @property
+    def rng(self):
+        """Return a new generator built from a fixed seed on every access."""
+        return np.random.default_rng(20250817)
+
+    def clone_model(self, model: ir.Model) -> ir.Model:
+        """Return a copy of ``model`` obtained through a proto round trip."""
+        return ir.from_proto(ir.to_proto(model))
+
+    def run_test(
+        self,
+        base_model: ir.Model,
+        expected_op_types: list[str],
+        dtype: str = "float",
+    ):
+        """Apply ``rules`` to a copy of ``base_model`` and validate the result.
+
+        Args:
+            base_model: Model to rewrite; it is checked and shape-inferred first.
+            expected_op_types: Op types expected in the rewritten graph, in order.
+            dtype: ``"float"`` casts the random inputs to float32; any other
+                value keeps them as int32.
+        """
+        onnx_checker.CheckerPass(True)(base_model)
+        base_model = shape_inference.infer_shapes(base_model)
+        updated_model = self.clone_model(base_model)
+        _ = rules.apply_to_model(updated_model)
+
+        # Check expected op_types
+        self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types)
+
+        # Check inference
+        # The leading dimension of the first graph input is replaced by 2; the
+        # remaining dimensions are taken from the model.
+        inputs = (
+            self.rng.integers(
+                low=-10,
+                high=10,
+                size=(2, *updated_model.graph.inputs[0].shape[1:]),
+                dtype=np.int32,
+            ),
+        )
+        if dtype == "float":
+            inputs = (inputs[0].astype(np.float32),)
+
+        testing.assert_numerically_equal(
+            base_model,
+            updated_model,
+            inputs,
+        )
+
+        # Validate serialized model
+        output_model_proto = ir.serde.serialize_model(updated_model)
+        onnx.checker.check_model(output_model_proto, full_check=True)
+
+    def run_failed_condition_test(
+        self,
+        base_model: ir.Model,
+        rewrite_rule: RewriteRule,
+        expected_message: str,
+    ):
+        """Assert that ``rewrite_rule`` rejects ``base_model`` in its condition.
+
+        Args:
+            base_model: Model the rule is applied to (on a copy).
+            rewrite_rule: Rule expected to apply zero times.
+            expected_message: Regex searched in the recorded failure reason.
+        """
+        onnx_checker.CheckerPass(True)(base_model)
+
+        updated_model = self.clone_model(base_model)
+        tracer = MatchingTracer()
+        count = rewrite_rule.apply_to_model(updated_model, tracer=tracer)
+
+        # Check that the model is unchanged
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[rewrite_rule][0]
+        self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, expected_message)
+
+
+class TestFuseSuccessiveMinOrMax(_TestMinMaxToClipBase):
+    """Tests for the Min(Min(...)) and Max(Max(...)) fusion rules."""
+
+    @parameterized.expand(
+        [
+            ("int32_min", "int32", "Min"),
+            ("int32_max", "int32", "Max"),
+            ("float32_min", "float", "Min"),
+            ("float32_max", "float", "Max"),
+        ]
+    )
+    def test_successful_fuse_successive_min_or_max(self, _, dtype, op_type):
+        base_model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model ({dtype}[N, 32, 14, 17] X) => ({dtype} [N, ?, ?, ?] Y)
+            <{dtype}[1] cst1 = {{3}}, {dtype}[1] cst2 = {{6}}>
+            {{
+                x1 = {op_type}(X, cst1)
+                Y = {op_type}(x1, cst2)
+            }}
+        """)
+        self.run_test(base_model, expected_op_types=[op_type], dtype=dtype)
+
+    @parameterized.expand(
+        [
+            ("int32_min_multi", "int32", "Min"),
+            ("int32_max_multi", "int32", "Max"),
+            ("float32_min_multi", "float", "Min"),
+            ("float32_max_multi", "float", "Max"),
+        ]
+    )
+    def test_successful_fuse_successive_min_or_max_multiple_inputs(self, _, dtype, op_type):
+        base_model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model ({dtype}[N, 3, 3] X) => ({dtype}[N, 3, 3] Y)
+            <{dtype}[3] cst1 = {{2, 5, 8}},
+            {dtype}[1] cst2 = {{4}},
+            {dtype}[3] cst3 = {{3, 1, -6}},
+            {dtype}[1] cst4 = {{10}},
+            {dtype}[3] cst5 = {{-2, 7, 9}},
+            {dtype}[1] cst6 = {{0}},
+            {dtype}[3] cst7 = {{11, -3, 4}}>
+            {{
+                x1 = {op_type}(X, cst1, cst2, cst3, cst4)
+                Y = {op_type}(x1, cst5, cst6, cst7)
+            }}
+        """)
+        self.run_test(base_model, expected_op_types=[op_type], dtype=dtype)
+
+    @parameterized.expand(
+        [
+            ("int32_min", "Min"),
+            ("int32_max", "Max"),
+            ("float32_min", "Min"),
+            ("float32_max", "Max"),
+        ]
+    )
+    def test_successful_fuse_successive_min_or_max_constants(self, _, op_type):
+        base_model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y)
+
+            {{
+                x1 = {op_type}(X, cst1)
+                cst2 = Constant()
+                Y = {op_type}(x1, cst2)
+            }}
+        """)
+        self.run_test(base_model, expected_op_types=["Constant", op_type])
+
+    @parameterized.expand(
+        [
+            ("min_nonconst", "Min", min_min_rule),
+            ("max_nonconst", "Max", max_max_rule),
+        ]
+    )
+    def test_failure_fuse_successive_min_or_max_non_constant(self, _, op_type, rewrite_rule):
+        model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] Y)
+
+            {{
+                cst1 = ReduceMean(X)
+                x1 = {op_type}(X, cst1)
+                Y = {op_type}(x1, cst2)
+            }}
+        """)
+        self.run_failed_condition_test(model, rewrite_rule, "is not a constant.")
+
+    @parameterized.expand(
+        [
+            ("min_graph_input", "Min"),
+            ("max_graph_input", "Max"),
+        ]
+    )
+    def test_successful_fuse_successive_min_or_max_graph_inputs_as_constants(self, _, op_type):
+        base_model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X, float[1] cst1, float[1] cst2) => (float[N, ?, ?, ?] Y)
+
+            {{
+                x1 = {op_type}(X, cst1)
+                Y = {op_type}(x1, cst2)
+            }}
+        """)
+        self.run_test(base_model, expected_op_types=[op_type])
+
+
+class TestMinMaxToClip(_TestMinMaxToClipBase):
+    """Tests for the Max(Min(X, ...), ...) to Clip rule (``min_max_rule``)."""
+
+    def test_successful_min_max_to_clip(self):
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y)
+
+            {
+                x1 = Min(X, min)
+                Y = Max(x1, max)
+            }
+        """)
+        self.run_test(base_model, expected_op_types=["Clip"])
+
+    def test_successful_min_max_to_clip_constants(self):
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y)
+
+            {
+                x1 = Min(X, min)
+                max = Constant()
+                Y = Max(x1, max)
+            }
+        """)
+        self.run_test(base_model, expected_op_types=["Constant", "Clip"])
+
+    def test_successful_min_max_to_clip_graph_inputs_as_constants(self):
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y)
+
+            {
+                x1 = Min(X, min)
+                Y = Max(x1, max)
+            }
+        """)
+        self.run_test(base_model, expected_op_types=["Clip"])
+
+    def test_failure_min_max_to_clip_invalid_bounds(self):
+        """Min node should have the max value and Max node should have the min value."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y)
+
+            {
+                x1 = Min(X, min)
+                Y = Max(x1, max)
+            }
+        """)
+        self.run_failed_condition_test(base_model, min_max_rule, "Invalid bounds:")
+
+    def test_failure_fuse_min_max_to_clip_non_constant(self):
+        model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y)
+
+            {
+                min = ReduceMean(X)
+                x1 = Min(X, min)
+                Y = Max(x1, max)
+            }
+        """)
+        self.run_failed_condition_test(model, min_max_rule, "is not a constant.")
+
+    def test_failure_min_max_to_clip_need_scalars(self):
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 4, 4] X) => (float [N, ?, ?] Y)
+
+            {
+                x1 = Min(X, min)
+                Y = Max(x1, max)
+            }
+        """)
+        self.run_failed_condition_test(base_model, min_max_rule, "is not a scalar")
+
+
+class TestMaxMinToClip(_TestMinMaxToClipBase):
+    """Tests for the Min(Max(X, ...), ...) to Clip rule (``max_min_rule``)."""
+
+    def test_successful_max_min_to_clip(self):
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y)
+
+            {
+                x1 = Max(X, max)
+                Y = Min(x1, min)
+            }
+        """)
+        self.run_test(base_model, expected_op_types=["Clip"])
+
+    def test_successful_max_min_to_clip_constants(self):
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y)
+
+            {
+                x1 = Max(X, max)
+                min = Constant()
+                Y = Min(x1, min)
+            }
+        """)
+        self.run_test(base_model, expected_op_types=["Constant", "Clip"])
+
+    def test_successful_max_min_to_clip_graph_inputs_as_constants(self):
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y)
+
+            {
+                x1 = Max(X, max)
+                Y = Min(x1, min)
+            }
+        """)
+        self.run_test(base_model, expected_op_types=["Clip"])
+
+    def test_successful_max_min_to_clip_check_bounds(self):
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y)
+
+            {
+                x1 = Max(X, max)
+                Y = Min(x1, min)
+            }
+        """)
+        self.run_test(base_model, expected_op_types=["Clip"])
+
+    def test_failure_fuse_max_min_to_clip_non_constant(self):
+        model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y)
+
+            {
+                min = ReduceMean(X)
+                x1 = Max(X, max)
+                Y = Min(x1, min)
+            }
+        """)
+        self.run_failed_condition_test(model, max_min_rule, "is not a constant.")
+
+    def test_failure_max_min_to_clip_need_scalars(self):
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 4, 4] X) => (float [N, ?, ?] Y)
+
+            {
+                x1 = Max(X, min)
+                Y = Min(x1, max)
+            }
+        """)
+        self.run_failed_condition_test(base_model, max_min_rule, "is not a scalar")
+
+
+class TestIntegrationMinMaxToClip(_TestMinMaxToClipBase):
+    """Applies the whole rule set to a chain of alternating Min/Max nodes."""
+
+    def test_successful_full_chain_fusion(self):
+        model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 20] >
+            test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
+
+            {
+                x1 = Min(X, min1)
+                x2 = Min(x1, min2)
+                x3 = Max(x2, max1)
+                x4 = Max(x3, max2)
+                x5 = Min(x4, min3)
+                x6 = Max(x5, max3)
+                Y = Min(x6, min4)
+            }
+        """)
+        self.run_test(model, expected_op_types=["Clip", "Clip", "Clip"])
+
+
+if __name__ == "__main__":
+    unittest.main()