
Commit 6c092e2

Copilot and xadupre authored
Add fusion rule to remove Expand before broadcast-capable binary operators (#2862)
Adds a rewrite rule that eliminates redundant `Expand` nodes preceding binary operators that natively support NumPy-style broadcasting.

## Pattern

```
BinaryOp(Expand(x, shape), y) → BinaryOp(x, y)
BinaryOp(x, Expand(y, shape)) → BinaryOp(x, y)
```

## Safety check

The rule applies a dimension-by-dimension analysis to determine whether the `Expand` is redundant. For each dimension `i`, the expand is safe to remove if any of the following holds:

- `expand_shape[i] == 1`: expand cannot shrink a dimension, so it is a no-op.
- `x.shape[i] == expand_shape[i]`: the expand is a no-op at this dimension.
- `y.shape[i] == expand_shape[i]`: `y` already covers the expansion via its own broadcasting.

Otherwise the check fails conservatively.

Three producer-agnostic strategies are used to resolve the expand target shape:

1. **Constant expand shape**: When the `shape` argument is a compile-time constant, the check is applied directly. Individual dimensions of `x` or `y` may still be symbolic. For example, `Add(Expand(x=[N], shape=[3,4]), y=[3,4])` is optimized to `Add(x, y)` because `y` statically provides all expansion dimensions.
2. **Expand output shape annotation**: When `shape` is dynamic but the Expand node's output value already carries a shape annotation (e.g. after ONNX shape inference has been applied), those dimension values are used directly for the check. For example, after `onnx.shape_inference.infer_shapes`, `Expand(x=[N,1], Concat(Shape(x,0:1), Shape(x,1:2)))` gets output shape `[N,1]` and the rule fires.
3. **Binary op output shape**: When neither of the above is available, the rule verifies that `broadcast(x.shape, y.shape)` symbolically equals the binary op's output shape. If they agree, the binary op's own broadcasting already accounts for all the expansion and the Expand is redundant.
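The per-dimension safety check can be illustrated with a minimal standalone sketch. This is not the module's code: `expand_is_redundant` is a hypothetical helper written for illustration, with symbolic dimensions represented as plain strings.

```python
# Hypothetical sketch of the per-dimension safety check described above.
# Dimension values are ints, or strings standing in for symbolic dims like "N".
def expand_is_redundant(expand_shape, x_shape, y_shape):
    """Return True when every expanded dimension is already covered by x or y."""
    rank = len(expand_shape)
    for rev_i in range(rank):  # iterate right-aligned, last dim first
        e_d = expand_shape[rank - 1 - rev_i]
        if e_d == 1:
            continue  # Expand cannot shrink a dim, so this is a no-op.
        x_idx = len(x_shape) - 1 - rev_i
        x_d = x_shape[x_idx] if x_idx >= 0 else 1  # missing dims broadcast as 1
        if x_d == e_d:
            continue  # Expand is a no-op at this dimension.
        y_idx = len(y_shape) - 1 - rev_i
        y_d = y_shape[y_idx] if y_idx >= 0 else 1
        if y_d == e_d:
            continue  # y already supplies this dimension via its own broadcasting.
        return False  # Cannot verify safety; fail conservatively.
    return True

# Example from the description: Add(Expand(x=[N], shape=[3, 4]), y=[3, 4]).
print(expand_is_redundant([3, 4], ["N"], [3, 4]))  # True: y covers both dims
print(expand_is_redundant([3, 4], ["N"], [4]))     # False: nothing covers dim 0
```

The same traversal order (rightmost dimension first) is what makes missing leading dimensions of `x` or `y` default to 1, matching NumPy's right-aligned broadcasting rules.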
## Supported ops

`Add`, `Sub`, `Mul`, `Div`, `Pow`, `And`, `Or`, `Xor`, `BitwiseAnd`, `BitwiseOr`, `BitwiseXor`, `Greater`, `Less`, `Equal`, `GreaterOrEqual`, `LessOrEqual`, `Mod`, `PRelu`, `BitShift`

## Changes

- **`_remove_expand_before_binary_op.py`**: new module with `_ExpandFirstInput` / `_ExpandSecondInput` rule classes, `_compute_broadcast_shape` / `_check_dims_sufficient` helpers, and the exported `expand_before_binary_op_rules` `RewriteRuleSet`; the rule classes access `context.root` to obtain the Expand output and binary op output values.
- **`_remove_expand_before_binary_op_test.py`**: tests covering removal when safe (including dynamic shapes via shape annotations and binary op output shape matching), and non-removal when the expansion cannot be statically verified.
- **`rules/common/__init__.py`**: exports `expand_before_binary_op_rules`.

Fixes #2861
---------

Signed-off-by: Xavier Dupré <xadupre@microsoft.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
Co-authored-by: Xavier Dupré <xadupre@microsoft.com>
Co-authored-by: Xavier Dupré <xadupre@users.noreply.github.com>
1 parent c7d13fb commit 6c092e2

File tree

3 files changed

+647
-0
lines changed


onnxscript/rewriter/rules/common/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@
     "div_by_1_rule",
     "dropout_inference_rule",
     "dropout_zero_rule",
+    "expand_before_binary_op_rules",
     "flatten_to_reshape_rule",
     "fuse_batchnorm_into_conv_rule",
     "fuse_batchnorm_into_conv_transpose_rule",
@@ -125,6 +126,9 @@
     no_op_dynamic_scatter_nd_rule,
     no_op_static_scatter_nd_rule,
 )
+from onnxscript.rewriter.rules.common._remove_expand_before_binary_op import (
+    expand_before_binary_op_rules,
+)
 from onnxscript.rewriter.rules.common._remove_optional_bias import (
     remove_optional_bias_from_conv_rule,
     remove_optional_bias_from_conv_transpose_rule,
```
onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py

Lines changed: 295 additions & 0 deletions
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Fusion rule to remove an Expand node before a binary operator.

This implements the optimization:

    BinaryOp(Expand(x, shape), y) -> BinaryOp(x, y)
    BinaryOp(x, Expand(y, shape)) -> BinaryOp(x, y)

This is valid when the binary operator's broadcasting semantics would produce
the same output shape as first expanding the input and then applying the op.
"""

from __future__ import annotations

from onnxscript import ir
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._ir_utils import get_numpy_value
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet

# Binary operators in the ONNX standard opset that support numpy-style broadcasting.
_BROADCAST_BINARY_OPS: tuple[str, ...] = (
    "Add",
    "And",
    "BitShift",
    "BitwiseAnd",
    "BitwiseOr",
    "BitwiseXor",
    "Div",
    "Equal",
    "Greater",
    "GreaterOrEqual",
    "Less",
    "LessOrEqual",
    "Mod",
    "Mul",
    "Or",
    "Pow",
    "PRelu",
    "Sub",
    "Xor",
)


def _compute_broadcast_dim(d1, d2):
    """Return the numpy broadcast of two dimension values.

    Each dimension value may be an ``int`` or an ``onnx_ir.SymbolicDim``.
    Returns ``None`` when the result cannot be determined statically (e.g. two
    distinct symbolic values, neither of which is known to be 1).
    """
    if d1 == 1:
        return d2
    if d2 == 1:
        return d1
    if d1 == d2:
        return d1
    return None


def _compute_broadcast_shape(shape1: ir.Shape, shape2: ir.Shape) -> list | None:
    """Compute numpy-style broadcast shape symbolically.

    Returns the broadcast shape as a list of dimension values (``int`` or
    ``SymbolicDim``), or ``None`` when the result cannot be determined (e.g.
    unknown ranks or incompatible static dims).
    """
    rank1 = shape1.rank()
    rank2 = shape2.rank()
    if rank1 is None or rank2 is None:
        return None
    rank = max(rank1, rank2)
    result = []
    for i in range(rank):
        idx1 = rank1 - rank + i
        d1 = shape1[idx1] if idx1 >= 0 else 1
        idx2 = rank2 - rank + i
        d2 = shape2[idx2] if idx2 >= 0 else 1
        d = _compute_broadcast_dim(d1, d2)
        if d is None:
            return None
        result.append(d)
    return result


def _check_dims_sufficient(
    expand_shape: ir.Shape,
    x_shape: ir.Shape,
    y_shape: ir.Shape,
) -> MatchResult:
    """Check that x and y together cover every dimension of the expand target.

    For each dimension ``i`` of *expand_shape* (right-aligned) the expand is
    considered redundant when at least one of the following holds:

    - ``expand_shape[i] == 1`` - expand cannot shrink a dim, so ``x_d`` must
      also be 1 and both with and without expand produce ``y_d``.
    - ``x_d == expand_shape[i]`` - the expand is a no-op at this dim.
    - ``y_d == expand_shape[i]`` - ``y`` already supplies this expansion.

    Comparisons work for both ``int`` and ``SymbolicDim`` values.
    """
    check_result = MatchResult()
    e_rank = expand_shape.rank()
    x_rank = x_shape.rank()
    y_rank = y_shape.rank()
    if e_rank is None:
        return check_result.fail("Expand output rank is unknown.")

    for rev_i in range(e_rank):
        i = e_rank - 1 - rev_i
        e_d = expand_shape[i]

        if isinstance(e_d, int) and e_d == 1:
            continue  # expand cannot shrink; x_d is also 1, no-op

        x_idx = x_rank - 1 - rev_i
        x_d = x_shape[x_idx] if x_idx >= 0 else 1
        if x_d == e_d:
            continue  # expand is a no-op at this dimension

        y_idx = y_rank - 1 - rev_i
        y_d = y_shape[y_idx] if y_idx >= 0 else 1
        if y_d == e_d:
            continue  # y already supplies this dimension

        return check_result.fail(
            f"Cannot verify that removing Expand is safe at dimension {i}: "
            f"x_d={x_d!r}, expand_d={e_d!r}, y_d={y_d!r}."
        )

    return check_result


def _check_expand_removable(
    expand_input: ir.Value,
    shape: ir.Value,
    other_input: ir.Value,
    expand_output: ir.Value | None = None,
    binary_op_output: ir.Value | None = None,
) -> MatchResult:
    """Check if an Expand node can be safely removed before a binary op.

    The Expand ``expanded_x = Expand(x, expand_shape)`` before a binary op
    ``out = BinaryOp(expanded_x, y)`` is redundant when the binary op's own
    broadcasting produces the same output as if the expand had been applied.

    Three strategies are tried in order:

    1. **Constant expand shape** - When ``shape`` is a compile-time constant,
       the dimension values are extracted from it and the check is performed
       directly.

    2. **Expand output shape annotation** - When ``shape`` is dynamic but the
       Expand node's output value already carries a shape annotation (e.g.
       after ONNX shape inference has been applied to the model), those
       dimension values are used for the check.

    3. **Binary op output shape** - When neither of the above is available,
       the rule verifies that ``broadcast(x.shape, y.shape)`` symbolically
       equals the binary op's output shape. If they agree, the binary op's
       own broadcasting already accounts for all the expansion and the
       Expand is redundant.

    Args:
        expand_input: The value fed into the Expand node (``x``).
        shape: The target shape operand of the Expand node.
        other_input: The other operand of the binary op (``y``).
        expand_output: The output value of the Expand node. Required for
            strategy 2.
        binary_op_output: The output value of the binary op. Required for
            strategy 3.

    Returns:
        A :class:`MatchResult` that is successful when the Expand can be
        removed.
    """
    check_result = MatchResult()

    x_shape = expand_input.shape
    y_shape = other_input.shape
    if x_shape is None or y_shape is None:
        return check_result.fail("Input shapes are not known.")

    x_rank = x_shape.rank()
    y_rank = y_shape.rank()

    # --- Strategy 1: expand target shape is a compile-time constant ---
    expand_shape_val = get_numpy_value(shape)
    if expand_shape_val is not None:
        expand_shape = tuple(int(v) for v in expand_shape_val.tolist())
        expand_rank = len(expand_shape)

        for rev_i in range(expand_rank):
            i = expand_rank - 1 - rev_i
            e_d = expand_shape[i]  # always a known integer from numpy

            if e_d == 1:
                continue  # expand cannot shrink; x_d is also 1, no-op

            x_idx = x_rank - 1 - rev_i
            x_d = x_shape[x_idx] if x_idx >= 0 else 1
            if isinstance(x_d, int) and x_d == e_d:
                continue  # expand is a no-op at this dimension

            y_idx = y_rank - 1 - rev_i
            y_d = y_shape[y_idx] if y_idx >= 0 else 1
            if isinstance(y_d, int) and y_d == e_d:
                continue  # y already supplies this dimension

            return check_result.fail(
                f"Cannot verify that removing Expand is safe at dimension {i}: "
                f"x_d={x_d!r}, expand_d={e_d}, y_d={y_d!r}."
            )

        return check_result

    # --- Strategy 2: Expand output shape is known (e.g. from shape inference) ---
    if expand_output is not None and expand_output.shape is not None:
        return _check_dims_sufficient(expand_output.shape, x_shape, y_shape)

    # --- Strategy 3: use the binary op's output shape ---
    # broadcast(x.shape, y.shape) must equal the binary op's output shape.
    # If it does, the binary op's own broadcasting already produces the same
    # result as first expanding x and then broadcasting.
    if binary_op_output is not None and binary_op_output.shape is not None:
        op_output_shape = binary_op_output.shape
        if op_output_shape.rank() is not None:
            computed = _compute_broadcast_shape(x_shape, y_shape)
            if computed is not None and len(computed) == op_output_shape.rank():
                if all(c == a for c, a in zip(computed, op_output_shape)):
                    return check_result
        return check_result.fail(
            "broadcast(x.shape, y.shape) does not match the binary op output shape."
        )

    return check_result.fail(
        "Expand target shape is not a constant and no shape annotations are available."
    )


class _ExpandFirstInput(RewriteRuleClassBase):
    """Removes ``BinaryOp(Expand(x, shape), y)`` -> ``BinaryOp(x, y)``."""

    def __init__(self, op_type: str) -> None:
        super().__init__(f"ExpandFirst_{op_type}", remove_nodes=False)
        self._op_type = op_type

    def pattern(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
        return getattr(op, self._op_type)(op.Expand(x, shape), y)

    def check(self, context, x: ir.Value, shape: ir.Value, y: ir.Value) -> MatchResult:
        expand_output = context.root.inputs[0] if context.root.inputs else None
        binary_op_output = context.root.outputs[0] if context.root.outputs else None
        return _check_expand_removable(
            x, shape, y, expand_output=expand_output, binary_op_output=binary_op_output
        )

    def rewrite(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
        return getattr(op, self._op_type)(x, y)


class _ExpandSecondInput(RewriteRuleClassBase):
    """Removes ``BinaryOp(x, Expand(y, shape))`` -> ``BinaryOp(x, y)``."""

    def __init__(self, op_type: str) -> None:
        super().__init__(f"ExpandSecond_{op_type}", remove_nodes=False)
        self._op_type = op_type

    def pattern(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
        return getattr(op, self._op_type)(x, op.Expand(y, shape))

    def check(self, context, x: ir.Value, y: ir.Value, shape: ir.Value) -> MatchResult:
        expand_output = context.root.inputs[1] if context.root.inputs else None
        binary_op_output = context.root.outputs[0] if context.root.outputs else None
        return _check_expand_removable(
            y, shape, x, expand_output=expand_output, binary_op_output=binary_op_output
        )

    def rewrite(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
        return getattr(op, self._op_type)(x, y)


def _make_expand_before_binary_op_rules() -> list:
    """Create rewrite rules for removing Expand before each supported binary op."""
    rules = []
    for op_type in _BROADCAST_BINARY_OPS:
        rules.append(_ExpandFirstInput.rule(op_type))
        rules.append(_ExpandSecondInput.rule(op_type))
    return rules


expand_before_binary_op_rules = RewriteRuleSet(_make_expand_before_binary_op_rules())
```
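Strategy 3 rests on a symbolic broadcast-shape computation. The sketch below reimplements that idea standalone for illustration; `broadcast_dim` and `broadcast_shape` are hypothetical names (the module uses `_compute_broadcast_dim` / `_compute_broadcast_shape` over `ir.Shape` objects), and symbolic dimensions are modeled here as plain strings.

```python
# Hypothetical standalone sketch of symbolic numpy-style broadcasting over
# shape tuples, where a string such as "N" stands in for a symbolic dim.
def broadcast_dim(d1, d2):
    """Broadcast two dims; None means statically undecidable or incompatible."""
    if d1 == 1:
        return d2
    if d2 == 1 or d1 == d2:
        return d1
    return None  # e.g. two distinct symbols, or mismatched static dims

def broadcast_shape(s1, s2):
    """Right-align the two shapes and broadcast dimension by dimension."""
    rank = max(len(s1), len(s2))
    out = []
    for i in range(rank):
        i1 = len(s1) - rank + i
        i2 = len(s2) - rank + i
        d1 = s1[i1] if i1 >= 0 else 1  # missing leading dims broadcast as 1
        d2 = s2[i2] if i2 >= 0 else 1
        d = broadcast_dim(d1, d2)
        if d is None:
            return None
        out.append(d)
    return out

# x: [N, 1], y: [1, 4]; if the binary op's recorded output shape is [N, 4],
# it equals broadcast(x, y), so the Expand contributed nothing.
print(broadcast_shape(["N", 1], [1, 4]))  # ['N', 4]
print(broadcast_shape([3], [4]))          # None: incompatible, check fails
```

When `broadcast_shape(x, y)` equals the binary op's annotated output shape, the op's own broadcasting already reproduces whatever the Expand would have done, which is exactly the condition strategy 3 tests.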
