diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py
index 0b01bade72..14ed3587f3 100644
--- a/onnxscript/rewriter/rules/common/__init__.py
+++ b/onnxscript/rewriter/rules/common/__init__.py
@@ -2,11 +2,13 @@
 # Licensed under the MIT License.
 __all__ = [
     "add_0_rule",
+    "affine_conv_fusion_rule",
     "cast_cast_rule",
     "cast_constant_of_shape_rule",
     "cast_constant_of_shape_without_value_rule",
     "collapse_slice_rule",
     "collapse_slice2_rule",
+    "conv_affine_fusion_rule",
     "div_by_1_rule",
     "dropout_inference_rule",
     "dropout_zero_rule",
@@ -14,6 +16,7 @@
     "fuse_batchnorm_into_conv_rule",
     "fuse_batchnorm_into_conv_transpose_rule",
     "fuse_batchnorm_into_gemm_rule",
+    "fuse_hardswish_rules",
     "fuse_pad_into_conv_integer_rule",
     "fuse_pad_into_conv_rule",
     "min_min_rule",
@@ -76,6 +79,11 @@
     fuse_batchnorm_into_conv_transpose_rule,
     fuse_batchnorm_into_gemm_rule,
 )
+from onnxscript.rewriter.rules.common._fuse_conv_affine import (
+    affine_conv_fusion_rule,
+    conv_affine_fusion_rule,
+)
+from onnxscript.rewriter.rules.common._fuse_hardswish import fuse_hardswish_rules
 from onnxscript.rewriter.rules.common._fuse_pad_into_conv import (
     fuse_pad_into_conv_integer_rule,
     fuse_pad_into_conv_rule,
diff --git a/onnxscript/rewriter/rules/common/_fuse_conv_affine.py b/onnxscript/rewriter/rules/common/_fuse_conv_affine.py
new file mode 100644
index 0000000000..2aaba5cd73
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_fuse_conv_affine.py
@@ -0,0 +1,112 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Absorbs an affine operation (scalar Mul followed by scalar Add) into a convolution (best effort):
+- Conv(Add(Mul(x))) -> Conv (only convolutions without padding can be fused)
+- Add(Mul(Conv(x))) -> Conv (for all convolutions)
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import onnx_ir as ir
+
+from onnxscript.rewriter import pattern
+from onnxscript.rewriter._basics import MatchResult
+from onnxscript.rewriter._ir_utils import get_const_value, get_singleton_value
+
+
+class _ConvAffineFusionBase(pattern.RewriteRuleClassBase):
+    def check(
+        self,
+        context,
+        x: ir.Value,
+        w: ir.Value,
+        b: ir.Value,
+        scale: ir.Value,
+        offset: ir.Value,
+        conv_out: ir.Value,
+    ) -> MatchResult:
+        check_result = MatchResult()
+        if get_const_value(w) is None:
+            return check_result.fail("The weight of Conv should be constant")
+        if get_const_value(b) is None:
+            return check_result.fail("The bias of Conv should be constant")
+        if get_singleton_value(scale) is None:
+            return check_result.fail("Operand for Mul should be a constant scalar value")
+        if get_singleton_value(offset) is None:
+            return check_result.fail("Operand for Add should be a constant scalar value")
+        return check_result
+
+
+class AffineConvFusion(_ConvAffineFusionBase):
+    """Pattern: scalar Mul + scalar Add + Conv (without padding) --> Conv"""
+
+    def pattern(
+        self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value
+    ) -> ir.Value:
+        return op.Conv(
+            x * scale + offset,
+            w,
+            b,
+            pads=[0, 0, 0, 0],
+            _allow_other_attributes=True,
+            _outputs=["conv_out"],
+        )
+
+    def rewrite(
+        self,
+        op: ir.tape.Tape,
+        x: ir.Value,
+        w: ir.Value,
+        b: ir.Value,
+        scale: ir.Value,
+        offset: ir.Value,
+        conv_out: ir.Value,
+    ) -> ir.Value:
+        scale_value = scale.const_value.numpy()
+        offset_value = offset.const_value.numpy()
+        w_value = w.const_value.numpy()
+        b_value = b.const_value.numpy()
+        scaled_w_value = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled")
+        offset_bias = ir.tensor(
+            b_value + np.sum(w_value * offset_value, axis=(1, 2, 3), keepdims=False)
+        )
+        offset_bias = op.initializer(offset_bias, b.name + "_offset")
+        conv_attributes = conv_out.producer().attributes
+        return op.Conv(x, scaled_w_value, offset_bias, **conv_attributes)
+
+
+class ConvAffineFusion(_ConvAffineFusionBase):
+    """Pattern: Conv + scalar Mul + scalar Add --> Conv"""
+
+    def pattern(
+        self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value
+    ) -> ir.Value:
+        return (
+            op.Conv(x, w, b, _allow_other_attributes=True, _outputs=["conv_out"]) * scale
+            + offset
+        )
+
+    def rewrite(
+        self,
+        op: ir.tape.Tape,
+        x: ir.Value,
+        w: ir.Value,
+        b: ir.Value,
+        scale: ir.Value,
+        offset: ir.Value,
+        conv_out: ir.Value,
+    ) -> ir.Value:
+        scale_value = scale.const_value.numpy()
+        offset_value = offset.const_value.numpy()
+        w_value = w.const_value.numpy()
+        b_value = b.const_value.numpy()
+        scaled_w_weight = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled")
+        offset_bias = ir.tensor(b_value * scale_value + offset_value)
+        offset_bias = op.initializer(offset_bias, b.name + "_offset")
+        conv_attributes = conv_out.producer().attributes
+        return op.Conv(x, scaled_w_weight, offset_bias, **conv_attributes)
+
+
+affine_conv_fusion_rule = AffineConvFusion().rule()
+conv_affine_fusion_rule = ConvAffineFusion().rule()
diff --git a/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py b/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py
new file mode 100644
index 0000000000..4f1f671f43
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py
@@ -0,0 +1,115 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import unittest
+
+import numpy as np
+
+from onnxscript import ir
+from onnxscript.rewriter import rewrite, testing
+from onnxscript.rewriter.rules.common import (
+    affine_conv_fusion_rule,
+    conv_affine_fusion_rule,
+)
+
+
+class FuseConvAffineTest(unittest.TestCase):
+    def clone_model(self, model: ir.Model) -> ir.Model:
+        return ir.from_proto(ir.to_proto(model))
+
+    def test_conv_affine_fusion(self):
+        tape = ir.tape.Tape()
+        x = ir.Input(
+            "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT)
+        )
+        w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w"))
+        b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b"))
+        scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale"))
+        offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset"))
+
+        conv_out = tape.op("Conv", [x, w, b], attributes={"pads": [1, 1, 1, 1]})
+        mul_out = tape.op("Mul", [conv_out, scale])
+        z = tape.op(
+            "Add",
+            [mul_out, offset],
+            output=ir.Input(
+                "z",
+                shape=ir.Shape([1, 3, 32, 32]),
+                type=ir.TensorType(ir.DataType.FLOAT),
+            ),
+        )
+
+        model = ir.Model(
+            ir.Graph(
+                inputs=[x],
+                outputs=[z],
+                nodes=tape.nodes,
+                initializers=tape.initializers,
+                opset_imports={"": 17},
+            ),
+            ir_version=8,
+        )
+        rewritten_model = self.clone_model(model)
+        rewritten_model = rewrite(
+            rewritten_model,
+            pattern_rewrite_rules=[conv_affine_fusion_rule],
+        )
+        # Check that Mul and Add are fused into Conv
+        self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes())
+
+        # Check that the results are numerically equal
+        rng = np.random.default_rng(42)
+        inputs = [
+            rng.random((1, 3, 32, 32), dtype=np.float32),
+        ]
+        testing.assert_numerically_equal(model, rewritten_model, inputs)
+
+    def test_affine_conv_fusion_without_pad(self):
+        tape = ir.tape.Tape()
+        x = ir.Input(
+            "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT)
+        )
+        w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w"))
+        b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b"))
+        scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale"))
+        offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset"))
+
+        mul_out = tape.op("Mul", [x, scale])
+        z = tape.op(
+            "Add",
+            [mul_out, offset],
+            output=ir.Input(
+                "z",
+                shape=ir.Shape([1, 3, 32, 32]),
+                type=ir.TensorType(ir.DataType.FLOAT),
+            ),
+        )
+        conv_out = tape.op("Conv", [z, w, b], attributes={"pads": [0, 0, 0, 0]})
+
+        model = ir.Model(
+            ir.Graph(
+                inputs=[x],
+                outputs=[conv_out],
+                nodes=tape.nodes,
+                initializers=tape.initializers,
+                opset_imports={"": 17},
+            ),
+            ir_version=8,
+        )
+        rewritten_model = self.clone_model(model)
+        rewritten_model = rewrite(
+            rewritten_model,
+            pattern_rewrite_rules=[affine_conv_fusion_rule],
+        )
+        # Check that Mul and Add are fused into Conv
+        self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes())
+
+        # Check that the results are numerically equal
+        rng = np.random.default_rng(42)
+        inputs = [
+            rng.random((1, 3, 32, 32), dtype=np.float32),
+        ]
+        testing.assert_numerically_equal(model, rewritten_model, inputs)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/rewriter/rules/common/_fuse_hardswish.py b/onnxscript/rewriter/rules/common/_fuse_hardswish.py
new file mode 100644
index 0000000000..6d2e8c84e1
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_fuse_hardswish.py
@@ -0,0 +1,142 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Performs the following transformations:
+- Div(Clip(Add(x))) -> HardSigmoid
+- Div(Mul(Clip(Add(x)), x)) and Mul(HardSigmoid(x), x) -> HardSwish
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import onnx_ir as ir
+
+from onnxscript.rewriter import pattern
+from onnxscript.rewriter._basics import MatchResult
+from onnxscript.rewriter._ir_utils import is_singleton_value
+from onnxscript.rewriter._rewrite_rule import RewriteRuleSet
+
+
+class _HardSigmoidFusionBase(pattern.RewriteRuleClassBase):
+    """Both fusions require the same constant values, which are checked in this base class."""
+
+    def check(
+        self,
+        op,
+        x: ir.Value,
+        clip_min: ir.Value,
+        clip_max: ir.Value,
+        bias: ir.Value,
+        divisor: ir.Value,
+    ) -> MatchResult:
+        check_result = MatchResult()
+
+        if not is_singleton_value(clip_min, 0.0, rtol=1e-4):
+            return check_result.fail("Fusion requires a Clip min value of 0")
+        if not is_singleton_value(clip_max, 6.0, rtol=1e-4):
+            return check_result.fail("Fusion requires a Clip max value of 6")
+        if not is_singleton_value(bias, 3.0, rtol=1e-4):
+            return check_result.fail("Fusion requires a bias value of 3")
+        if not is_singleton_value(divisor, 6.0, rtol=1e-4):
+            return check_result.fail("Fusion requires a divisor value of 6")
+        return check_result
+
+
+class HardSwishFusion(_HardSigmoidFusionBase):
+    """Fuse Add(_, 3) + Clip<0, 6>(_) + Mul + Div(_, 6) into HardSwish.
+
+    HardSigmoid fusion cannot be applied first here: the Mul appears before
+    the Div, so the Add + Clip + Div pattern required for HardSigmoid does not match.
+    """
+
+    def pattern(
+        self,
+        op,
+        x: ir.Value,
+        clip_min: ir.Value,
+        clip_max: ir.Value,
+        bias: ir.Value,
+        divisor: ir.Value,
+    ) -> ir.Value:
+        out = op.Clip(x + bias, clip_min, clip_max) * x
+        out = out / divisor
+        return out
+
+    def rewrite(
+        self,
+        op,
+        x: ir.Value,
+        clip_min: ir.Value,
+        clip_max: ir.Value,
+        bias: ir.Value,
+        divisor: ir.Value,
+    ) -> ir.Value:
+        return op.HardSwish(x)
+
+
+class HardSwishFusionFromHardSigmoid(pattern.RewriteRuleClassBase):
+    """Fuse HardSigmoid + Mul into HardSwish"""
+
+    def pattern(self, op, x: ir.Value) -> ir.Value:
+        # alpha=1/6 cannot be matched exactly as a float attribute, so it is validated with np.isclose in check()
+        out = op.HardSigmoid(x, _allow_other_attributes=True, _outputs=["hardsigmoid_out"])
+        out = out * x
+        return out
+
+    def check(self, op, x: ir.Value, hardsigmoid_out: ir.Value) -> MatchResult:
+        check_result = MatchResult()
+        hardsigmoid = hardsigmoid_out.producer()
+        # Use the getter with a default so a missing 'alpha' / 'beta' attribute does not raise
+        alpha = hardsigmoid.attributes.get_float("alpha", -1)
+        beta = hardsigmoid.attributes.get_float("beta", -1)
+        if not np.isclose(alpha, 1 / 6):
+            return check_result.fail(
+                "HardSigmoid alpha must be 1/6 to be fused into HardSwish"
+            )
+        if not np.isclose(beta, 0.5):
+            return check_result.fail(
+                "HardSigmoid beta must be 0.5 to be fused into HardSwish"
+            )
+        return check_result
+
+    def rewrite(self, op, x: ir.Value, hardsigmoid_out: ir.Value) -> ir.Value:
+        return op.HardSwish(x)
+
+
+class HardSigmoidFusion(_HardSigmoidFusionBase):
+    """Fuse Add + Clip + Div into HardSigmoid, but only for the HardSwish hyper-parameters: alpha=1/6, beta=0.5"""
+
+    def pattern(
+        self,
+        op,
+        x: ir.Value,
+        clip_min: ir.Value,
+        clip_max: ir.Value,
+        bias: ir.Value,
+        divisor: ir.Value,
+    ) -> ir.Value:
+        out = op.Clip(x + bias, clip_min, clip_max)
+        out = out / divisor
+        return out
+
+    def rewrite(
+        self,
+        op,
+        x: ir.Value,
+        clip_min: ir.Value,
+        clip_max: ir.Value,
+        bias: ir.Value,
+        divisor: ir.Value,
+    ) -> ir.Value:
+        return op.HardSigmoid(x, alpha=1 / 6, beta=0.5)
+
+
+def fuse_hardswish_rules() -> RewriteRuleSet:
+    """Returns the rewrite rules for fusing HardSwish and HardSigmoid."""
+    return RewriteRuleSet(
+        [
+            HardSwishFusion().rule(),
+            HardSigmoidFusion().rule(),
+            HardSwishFusionFromHardSigmoid().rule(),
+        ],
+        commute=True,
+    )
diff --git a/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py b/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py
new file mode 100644
index 0000000000..36556e9cff
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py
@@ -0,0 +1,117 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import unittest
+
+import numpy as np
+import onnx
+import onnx_ir as ir
+import onnxruntime as ort
+from onnx_ir.passes.common import onnx_checker, shape_inference
+
+from onnxscript import optimizer
+from onnxscript.rewriter import testing
+from onnxscript.rewriter.rules.common import fuse_hardswish_rules
+
+
+class FuseHardSwishTest(unittest.TestCase):
+    @property
+    def rng(self):
+        return np.random.default_rng(20250621)
+
+    def clone_model(self, model: ir.Model) -> ir.Model:
+        return ir.from_proto(ir.to_proto(model))
+
+    def run_test(
+        self,
+        base_model: ir.Model,
+        expected_op_types: list[str],
+        dtype: str = "float",
+    ):
+        onnx_checker.CheckerPass(True)(base_model)
+        base_model = shape_inference.infer_shapes(base_model)
+        updated_model = self.clone_model(base_model)
+        _ = fuse_hardswish_rules().apply_to_model(updated_model)
+
+        # Polish model to remove unused constants
+        updated_model = optimizer.optimize(updated_model)
+
+        # Check expected op_types
+        self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types)
+
+        # Check inference
+        inputs = (self.rng.integers(low=-10, high=10, size=(2 * 32), dtype=np.int32),)
+        if dtype == "float":
+            inputs = (inputs[0].astype(np.float32),)
+
+        testing.assert_numerically_equal(
+            base_model,
+            updated_model,
+            inputs,
+            ort_optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
+        )
+
+        # Validate serialized model
+        output_model_proto = ir.to_proto(updated_model)
+        onnx.checker.check_model(output_model_proto, full_check=True)
+
+    def test_hardsigmoid_fusion(self):
+        model_text = """
+            < ir_version: 10, opset_import: ["" : 20] >
+            hardsigmoid (float[N] x) => (float[N] y) {
+                three = Constant <value_float = 3.0> ()
+                six = Constant <value_float = 6.0> ()
+                zero = Constant <value_float = 0.0> ()
+                x_plus_3 = Add(x, three)
+                clipped = Clip(x_plus_3, zero, six)
+                y = Div(clipped, six)
+            }
+        """
+        model = ir.from_onnx_text(model_text)
+        self.run_test(model, ["HardSigmoid"])
+
+    def test_hardswish_fusion(self):
+        model_text = """
+            < ir_version: 10, opset_import: ["" : 20] >
+            hardswish (float[N] x) => (float[N] y) {
+                three = Constant <value_float = 3.0> ()
+                six = Constant <value_float = 6.0> ()
+                zero = Constant <value_float = 0.0> ()
+                x_plus_3 = Add(x, three)
+                clipped = Clip(x_plus_3, zero, six)
+                mul_x = Mul(clipped, x)
+                y = Div(mul_x, six)
+            }
+        """
+        model = ir.from_onnx_text(model_text)
+        self.run_test(model, ["HardSwish"])
+
+    def test_hardswish_fusion_mul_last(self):
+        model_text = """
+            < ir_version: 10, opset_import: ["" : 20] >
+            hardswish (float[N] x) => (float[N] y) {
+                three = Constant <value_float = 3.0> ()
+                six = Constant <value_float = 6.0> ()
+                zero = Constant <value_float = 0.0> ()
+                x_plus_3 = Add(x, three)
+                clipped = Clip(x_plus_3, zero, six)
+                div_x = Div(clipped, six)
+                y = Mul(div_x, x)
+            }
+        """
+        model = ir.from_onnx_text(model_text)
+        self.run_test(model, ["HardSwish"])
+
+    def test_hardswish_fusion_from_sigmoid(self):
+        model_text = """
+            < ir_version: 10, opset_import: ["" : 20] >
+            hardswish (float[N] x) => (float[N] y) {
+                hardsigmoid_out = HardSigmoid <alpha = 0.16666667, beta = 0.5> (x)
+                y = Mul(hardsigmoid_out, x)
+            }
+        """
+        model = ir.from_onnx_text(model_text)
+        self.run_test(model, ["HardSwish"])
+
+
+if __name__ == "__main__":
+    unittest.main()
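
Note (reviewer sketch, not part of the patch): the folding identities these rules rely on can be checked with a minimal standalone numpy script. The `conv1x1` helper, tensor shapes, and seed below are illustrative assumptions; a 1x1 kernel is used only so the convolution reduces to per-pixel channel mixing. The same identities hold for any kernel size, and the affine-before-Conv direction additionally requires zero padding, which is why `AffineConvFusion` only matches `pads=[0, 0, 0, 0]`. The last check restates the HardSwish identity exploited by the HardSwish/HardSigmoid rules.

```python
import numpy as np

rng = np.random.default_rng(0)
n, cin, cout, h, w = 2, 3, 4, 5, 5
x = rng.standard_normal((n, cin, h, w)).astype(np.float32)
W = rng.standard_normal((cout, cin, 1, 1)).astype(np.float32)
b = rng.standard_normal(cout).astype(np.float32)
s, o = np.float32(2.0), np.float32(3.0)  # scalar scale (Mul) and offset (Add)

def conv1x1(x, W, b):
    # 1x1 convolution, stride 1, no padding: per-pixel channel mixing plus bias.
    return np.einsum("oi,nihw->nohw", W[:, :, 0, 0], x) + b[None, :, None, None]

# ConvAffineFusion: Conv(x, W, b) * s + o == Conv(x, s*W, s*b + o)
np.testing.assert_allclose(
    conv1x1(x, W, b) * s + o, conv1x1(x, s * W, s * b + o), rtol=1e-5, atol=1e-5
)

# AffineConvFusion (no padding): Conv(x*s + o, W, b) == Conv(x, s*W, b + o * sum(W))
np.testing.assert_allclose(
    conv1x1(x * s + o, W, b),
    conv1x1(x, s * W, b + o * W.sum(axis=(1, 2, 3))),
    rtol=1e-5,
    atol=1e-5,
)

# HardSwish identity: HardSwish(x) == x * Clip(x + 3, 0, 6) / 6
#                                  == x * HardSigmoid(x; alpha=1/6, beta=0.5)
xv = rng.standard_normal(16).astype(np.float32)
np.testing.assert_allclose(
    xv * np.clip(xv + 3, 0, 6) / 6,
    xv * np.clip(xv / 6 + 0.5, 0.0, 1.0),
    rtol=1e-5,
    atol=1e-6,
)
```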