Skip to content

Commit 161bcd0

Browse files
Copilot and xadupre committed
Add fusion rule to remove Expand node before broadcast-capable binary operators
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
1 parent a5d2384 commit 161bcd0

File tree

3 files changed

+379
-0
lines changed

3 files changed

+379
-0
lines changed

onnxscript/rewriter/rules/common/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"div_by_1_rule",
1313
"dropout_inference_rule",
1414
"dropout_zero_rule",
15+
"expand_before_binary_op_rules",
1516
"flatten_to_reshape_rule",
1617
"fuse_batchnorm_into_conv_rule",
1718
"fuse_batchnorm_into_conv_transpose_rule",
@@ -125,6 +126,9 @@
125126
no_op_dynamic_scatter_nd_rule,
126127
no_op_static_scatter_nd_rule,
127128
)
129+
from onnxscript.rewriter.rules.common._remove_expand_before_binary_op import (
130+
expand_before_binary_op_rules,
131+
)
128132
from onnxscript.rewriter.rules.common._remove_optional_bias import (
129133
remove_optional_bias_from_conv_rule,
130134
remove_optional_bias_from_conv_transpose_rule,
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Fusion rule to remove an Expand node before a binary operator.
4+
5+
This implements the optimization:
6+
7+
BinaryOp(Expand(x, shape), y) -> BinaryOp(x, y)
8+
BinaryOp(x, Expand(y, shape)) -> BinaryOp(x, y)
9+
10+
This is valid when the binary operator's broadcasting semantics would produce
11+
the same output shape as first expanding the input and then applying the op.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import numpy as np
17+
18+
from onnxscript import ir
19+
from onnxscript.rewriter._basics import MatchResult
20+
from onnxscript.rewriter._ir_utils import get_numpy_value
21+
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
22+
23+
# Binary operators in ONNX standard opset that support numpy-style broadcasting.
24+
_BROADCAST_BINARY_OPS: tuple[str, ...] = (
25+
"Add",
26+
"And",
27+
"BitShift",
28+
"BitwiseAnd",
29+
"BitwiseOr",
30+
"BitwiseXor",
31+
"Div",
32+
"Equal",
33+
"Greater",
34+
"GreaterOrEqual",
35+
"Less",
36+
"LessOrEqual",
37+
"Mod",
38+
"Mul",
39+
"Or",
40+
"Pow",
41+
"PRelu",
42+
"Sub",
43+
"Xor",
44+
)
45+
46+
47+
def _check_expand_removable(
    expand_input: ir.Value,
    shape: ir.Value,
    other_input: ir.Value,
) -> MatchResult:
    """Check if an Expand node can be safely removed before a binary op.

    The Expand is removable if the binary op's broadcasting produces the same
    output shape when using the original (pre-expand) tensor directly.

    Args:
        expand_input: The value fed into the Expand node.
        shape: The target shape operand of the Expand node (must be a constant).
        other_input: The other operand of the binary op.

    Returns:
        A MatchResult that is successful when the Expand can be removed.
    """
    check_result = MatchResult()

    # Need static shape info for both inputs.
    expand_input_shape = expand_input.shape
    other_shape = other_input.shape
    if expand_input_shape is None or other_shape is None:
        return check_result.fail("Input shapes are not statically known.")

    # Require fully static (integer-only) shapes to avoid symbolic dim issues.
    if not expand_input_shape.is_static() or not other_shape.is_static():
        return check_result.fail("Input shapes are not fully static.")

    # The Expand target shape must be a compile-time constant.
    expand_shape_val = get_numpy_value(shape)
    if expand_shape_val is None:
        return check_result.fail("Expand target shape is not a constant.")

    expand_shape = tuple(int(v) for v in expand_shape_val.tolist())
    x_shape = tuple(int(d) for d in expand_input_shape)
    y_shape = tuple(int(d) for d in other_shape)

    # Verify that removing the Expand does not change the binary op's output
    # shape.  Per the ONNX spec, Expand's output shape is the *broadcast* of
    # the input shape with the target shape (the target may contain 1s or be
    # of lower rank), not the target shape itself — so compute the actual
    # post-Expand shape before comparing.
    try:
        expand_output_shape = np.broadcast_shapes(x_shape, expand_shape)
        result_with_expand = np.broadcast_shapes(expand_output_shape, y_shape)
        result_without_expand = np.broadcast_shapes(x_shape, y_shape)
    except ValueError:
        # Also rejects an Expand whose target is incompatible with its input.
        return check_result.fail("Shapes are not broadcastable.")

    if result_with_expand != result_without_expand:
        return check_result.fail(
            f"Removing Expand would change output shape from "
            f"{result_with_expand} to {result_without_expand}."
        )

    return check_result
100+
101+
102+
class _ExpandFirstInput(RewriteRuleClassBase):
    """Rewrites ``BinaryOp(Expand(x, shape), y)`` into ``BinaryOp(x, y)``."""

    def __init__(self, op_type: str) -> None:
        # remove_nodes=False: the rewriter drops the dead Expand on its own
        # once its consumer is re-targeted.
        super().__init__(f"ExpandFirst_{op_type}", remove_nodes=False)
        self._op_type = op_type

    def pattern(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
        # Match: the binary op's first operand flows through an Expand.
        expanded = op.Expand(x, shape)
        binary_op = getattr(op, self._op_type)
        return binary_op(expanded, y)

    def check(self, context, x: ir.Value, shape: ir.Value, y: ir.Value) -> MatchResult:
        del context  # Unused
        # Here the Expand feeds operand 0, so `y` is the "other" input.
        return _check_expand_removable(x, shape, y)

    def rewrite(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
        binary_op = getattr(op, self._op_type)
        return binary_op(x, y)
118+
119+
120+
class _ExpandSecondInput(RewriteRuleClassBase):
    """Rewrites ``BinaryOp(x, Expand(y, shape))`` into ``BinaryOp(x, y)``."""

    def __init__(self, op_type: str) -> None:
        # remove_nodes=False: the rewriter drops the dead Expand on its own
        # once its consumer is re-targeted.
        super().__init__(f"ExpandSecond_{op_type}", remove_nodes=False)
        self._op_type = op_type

    def pattern(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
        # Match: the binary op's second operand flows through an Expand.
        expanded = op.Expand(y, shape)
        binary_op = getattr(op, self._op_type)
        return binary_op(x, expanded)

    def check(self, context, x: ir.Value, y: ir.Value, shape: ir.Value) -> MatchResult:
        del context  # Unused
        # Here the Expand feeds operand 1, so `x` is the "other" input.
        return _check_expand_removable(y, shape, x)

    def rewrite(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
        binary_op = getattr(op, self._op_type)
        return binary_op(x, y)
136+
137+
138+
def _make_expand_before_binary_op_rules() -> list:
    """Create rewrite rules for removing Expand before each supported binary op.

    For every operator in ``_BROADCAST_BINARY_OPS`` two rules are produced:
    one for an Expand on the first operand, one for the second.
    """
    return [
        rule_class.rule(op_type)
        for op_type in _BROADCAST_BINARY_OPS
        for rule_class in (_ExpandFirstInput, _ExpandSecondInput)
    ]
145+
146+
147+
# Ready-to-apply rule set covering both operand positions for every supported op.
expand_before_binary_op_rules = RewriteRuleSet(_make_expand_before_binary_op_rules())
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Tests for the remove-Expand-before-binary-op fusion rule."""
4+
5+
from __future__ import annotations
6+
7+
import unittest
8+
9+
import numpy as np
10+
import onnx.reference
11+
import parameterized
12+
13+
import onnxscript.ir as ir
14+
from onnxscript.rewriter.rules.common import _remove_expand_before_binary_op as mod
15+
16+
17+
def _run_model(model: ir.Model, feeds: dict) -> list:
    """Execute *model* on *feeds* with the ONNX reference evaluator.

    Returns the list of all graph outputs.
    """
    evaluator = onnx.reference.ReferenceEvaluator(ir.to_proto(model))
    return evaluator.run(None, feeds)
22+
23+
24+
class RemoveExpandBeforeBinaryOpTest(unittest.TestCase):
    """Tests for _remove_expand_before_binary_op rules."""

    def _apply_and_check(
        self,
        model_text: str,
        expected_count: int,
        expected_op_types: list[str],
    ) -> ir.Model:
        """Helper: apply the rules and verify the result.

        Args:
            model_text: The model in ONNX textual syntax.
            expected_count: Expected number of rule applications.
            expected_op_types: Expected op types remaining in the graph, in order.

        Returns:
            The rewritten model (for further numerical checks by the caller).
        """
        model = ir.from_onnx_text(model_text)
        count = mod.expand_before_binary_op_rules.apply_to_model(model)
        self.assertEqual(count, expected_count)
        actual_op_types = [node.op_type for node in model.graph]
        self.assertEqual(actual_op_types, expected_op_types)
        return model

    # ------------------------------------------------------------------
    # Cases where the Expand should be removed
    # ------------------------------------------------------------------

    @parameterized.parameterized.expand(
        [
            ("Add",),
            ("Sub",),
            ("Mul",),
            ("Div",),
        ]
    )
    def test_expand_first_input_same_shape_is_removed(self, op_type: str):
        """Expand producing same shape as input should be removed from BinaryOp."""
        model_text = f"""
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[3, 4] x, float[3, 4] y) => (float[3, 4] output)
            <int64[2] shape = {{3, 4}}>
            {{
                expanded = Expand(x, shape)
                output = {op_type}(expanded, y)
            }}
        """
        model = self._apply_and_check(model_text, 1, [op_type])

        # Verify numerical correctness: rewritten model must match the original.
        x = np.random.randn(3, 4).astype(np.float32)
        y = np.random.randn(3, 4).astype(np.float32)
        original = ir.from_onnx_text(model_text)
        expected = _run_model(original, {"x": x, "y": y})
        got = _run_model(model, {"x": x, "y": y})
        np.testing.assert_allclose(got[0], expected[0], rtol=1e-5)

    def test_expand_first_input_broadcast_covered_by_other_input(self):
        """Expand from [3, 4] to [4, 3, 4] can be removed when y has shape [4, 3, 4]."""
        model_text = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[3, 4] x, float[4, 3, 4] y) => (float[4, 3, 4] output)
            <int64[3] shape = {4, 3, 4}>
            {
                expanded = Expand(x, shape)
                output = Add(expanded, y)
            }
        """
        model = self._apply_and_check(model_text, 1, ["Add"])

        # Verify numerical correctness.
        x = np.random.randn(3, 4).astype(np.float32)
        y = np.random.randn(4, 3, 4).astype(np.float32)
        original = ir.from_onnx_text(model_text)
        expected = _run_model(original, {"x": x, "y": y})
        got = _run_model(model, {"x": x, "y": y})
        np.testing.assert_allclose(got[0], expected[0], rtol=1e-5)

    def test_expand_second_input_is_removed(self):
        """Expand on the second input of a binary op should be removed."""
        model_text = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[4, 3, 4] x, float[3, 4] y) => (float[4, 3, 4] output)
            <int64[3] shape = {4, 3, 4}>
            {
                expanded = Expand(y, shape)
                output = Mul(x, expanded)
            }
        """
        model = self._apply_and_check(model_text, 1, ["Mul"])

        # Verify numerical correctness.
        x = np.random.randn(4, 3, 4).astype(np.float32)
        y = np.random.randn(3, 4).astype(np.float32)
        original = ir.from_onnx_text(model_text)
        expected = _run_model(original, {"x": x, "y": y})
        got = _run_model(model, {"x": x, "y": y})
        np.testing.assert_allclose(got[0], expected[0], rtol=1e-5)

    def test_expand_with_broadcast_compatible_other_input(self):
        """Expand from [3] to [4, 3] can be removed when y has shape [4, 1]."""
        model_text = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[3] x, float[4, 1] y) => (float[4, 3] output)
            <int64[2] shape = {4, 3}>
            {
                expanded = Expand(x, shape)
                output = Add(expanded, y)
            }
        """
        model = self._apply_and_check(model_text, 1, ["Add"])

        # Verify numerical correctness.
        x = np.random.randn(3).astype(np.float32)
        y = np.random.randn(4, 1).astype(np.float32)
        original = ir.from_onnx_text(model_text)
        expected = _run_model(original, {"x": x, "y": y})
        got = _run_model(model, {"x": x, "y": y})
        np.testing.assert_allclose(got[0], expected[0], rtol=1e-5)

    def test_expand_sub_first_input_is_removed(self):
        """Expand on the first input of Sub should be removed."""
        model_text = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[3, 4] x, float[3, 4] y) => (float[3, 4] output)
            <int64[2] shape = {3, 4}>
            {
                expanded = Expand(x, shape)
                output = Sub(expanded, y)
            }
        """
        model = self._apply_and_check(model_text, 1, ["Sub"])

        # Verify numerical correctness (Sub is not commutative, so operand
        # order must survive the rewrite).
        x = np.random.randn(3, 4).astype(np.float32)
        y = np.random.randn(3, 4).astype(np.float32)
        original = ir.from_onnx_text(model_text)
        expected = _run_model(original, {"x": x, "y": y})
        got = _run_model(model, {"x": x, "y": y})
        np.testing.assert_allclose(got[0], expected[0], rtol=1e-5)

    def test_expand_div_second_input_is_removed(self):
        """Expand on the second input of Div should be removed."""
        model_text = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[4, 3, 4] x, float[3, 4] y) => (float[4, 3, 4] output)
            <int64[3] shape = {4, 3, 4}>
            {
                expanded = Expand(y, shape)
                output = Div(x, expanded)
            }
        """
        model = self._apply_and_check(model_text, 1, ["Div"])

        # Verify numerical correctness.
        x = np.random.randn(4, 3, 4).astype(np.float32)
        y = (np.random.randn(3, 4).astype(np.float32) + 2.0)  # avoid division by zero
        original = ir.from_onnx_text(model_text)
        expected = _run_model(original, {"x": x, "y": y})
        got = _run_model(model, {"x": x, "y": y})
        np.testing.assert_allclose(got[0], expected[0], rtol=1e-5)

    # ------------------------------------------------------------------
    # Cases where the Expand should NOT be removed
    # ------------------------------------------------------------------

    def test_expand_changes_output_shape_not_removed(self):
        """Expand that changes the output shape compared to direct broadcast must be kept."""
        # x has shape [3], expand to [4, 3], other input has shape [1].
        # With expand: broadcast([4, 3], [1]) = [4, 3]
        # Without expand: broadcast([3], [1]) = [3] <- different!
        model_text = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[3] x) => (float[4, 3] output)
            <int64[2] shape = {4, 3}, float[1] one = {1.0}>
            {
                expanded = Expand(x, shape)
                output = Add(expanded, one)
            }
        """
        model = ir.from_onnx_text(model_text)
        count = mod.expand_before_binary_op_rules.apply_to_model(model)
        self.assertEqual(count, 0)

    def test_expand_target_shape_not_constant_not_removed(self):
        """Expand with a dynamic (non-constant) shape cannot be removed."""
        model_text = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[3, 4] x, float[3, 4] y, int64[2] shape) => (float[3, 4] output)
            {
                expanded = Expand(x, shape)
                output = Add(expanded, y)
            }
        """
        model = ir.from_onnx_text(model_text)
        count = mod.expand_before_binary_op_rules.apply_to_model(model)
        self.assertEqual(count, 0)

    def test_expand_unknown_input_shape_not_removed(self):
        """Expand cannot be removed when the input shape is not statically known."""
        # 'x' has a symbolic dimension N, so its shape is not fully static.
        model_text = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[N] x, float[3, 4] y) => (float[3, 4] output)
            <int64[2] shape = {3, 4}>
            {
                expanded = Expand(x, shape)
                output = Add(expanded, y)
            }
        """
        model = ir.from_onnx_text(model_text)
        count = mod.expand_before_binary_op_rules.apply_to_model(model)
        self.assertEqual(count, 0)
226+
227+
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)