Skip to content

Commit aa0a89b

Browse files
Copilot and xadupre committed
Support dynamic shapes in expand-before-binary-op fusion rule
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
1 parent 161bcd0 commit aa0a89b

2 files changed

Lines changed: 77 additions & 24 deletions

File tree

onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313

1414
from __future__ import annotations
1515

16-
import numpy as np
17-
1816
from onnxscript import ir
1917
from onnxscript.rewriter._basics import MatchResult
2018
from onnxscript.rewriter._ir_utils import get_numpy_value
@@ -51,8 +49,19 @@ def _check_expand_removable(
5149
) -> MatchResult:
5250
"""Check if an Expand node can be safely removed before a binary op.
5351
54-
The Expand is removable if the binary op's broadcasting produces the same
55-
output shape when using the original (pre-expand) tensor directly.
52+
The Expand node ``expanded_x = Expand(x, expand_shape)`` before a binary op
53+
``out = BinaryOp(expanded_x, y)`` can be removed when the binary op's
54+
own broadcasting produces the same output shape as the explicit expand.
55+
56+
The condition at each dimension ``i`` (right-aligned) is::
57+
58+
max(expand_shape[i], y[i]) == max(x[i], y[i])
59+
60+
which simplifies to: either ``x[i] == expand_shape[i]`` (expand is a no-op
61+
here) or ``y[i] == expand_shape[i]`` (y already covers the expansion).
62+
63+
This check works with dynamic (symbolic) dimensions in x or y as long as
64+
the expand target shape is a compile-time constant.
5665
5766
Args:
5867
expand_input: The value fed into the Expand node.
@@ -64,36 +73,58 @@ def _check_expand_removable(
6473
"""
6574
check_result = MatchResult()
6675

67-
# Need static shape info for both inputs.
76+
# Need at least the rank of both inputs.
6877
expand_input_shape = expand_input.shape
6978
other_shape = other_input.shape
7079
if expand_input_shape is None or other_shape is None:
7180
return check_result.fail("Input shapes are not statically known.")
7281

73-
# Require fully static (integer-only) shapes to avoid symbolic dim issues.
74-
if not expand_input_shape.is_static() or not other_shape.is_static():
75-
return check_result.fail("Input shapes are not fully static.")
76-
7782
# The Expand target shape must be a compile-time constant.
7883
expand_shape_val = get_numpy_value(shape)
7984
if expand_shape_val is None:
8085
return check_result.fail("Expand target shape is not a constant.")
8186

8287
expand_shape = tuple(int(v) for v in expand_shape_val.tolist())
83-
x_shape = tuple(int(d) for d in expand_input_shape)
84-
y_shape = tuple(int(d) for d in other_shape)
85-
86-
# Verify that removing the Expand does not change the binary op's output shape.
87-
try:
88-
result_with_expand = np.broadcast_shapes(expand_shape, y_shape)
89-
result_without_expand = np.broadcast_shapes(x_shape, y_shape)
90-
except ValueError:
91-
return check_result.fail("Shapes are not broadcastable.")
88+
expand_rank = len(expand_shape)
89+
x_rank = expand_input_shape.rank()
90+
y_rank = other_shape.rank()
91+
92+
# Check each dimension of expand_shape (right-aligned).
93+
# For the expand to be removable at position i, we need:
94+
# max(e_d, y_d) == max(x_d, y_d)
95+
# which requires: e_d <= max(x_d, y_d).
96+
# Since a valid Expand can only broadcast from 1 (not shrink), if e_d > 1
97+
# then x_d is either 1 or e_d. The condition then reduces to:
98+
# x_d == e_d OR y_d == e_d.
99+
for rev_i in range(expand_rank):
100+
i = expand_rank - 1 - rev_i
101+
e_d = expand_shape[i] # always a known integer
102+
103+
# If expand target is 1 at this dim, expand cannot shrink a dimension, so
104+
# x_d must also be 1. The output is max(1, y_d) = y_d in both cases.
105+
if e_d == 1:
106+
continue
107+
108+
# Get x dimension (virtually 1 if x has fewer dims than expand_shape).
109+
x_idx = x_rank - 1 - rev_i
110+
x_d = expand_input_shape[x_idx] if x_idx >= 0 else 1
111+
112+
# If x's dimension already equals the expand target, expand is a no-op here.
113+
if isinstance(x_d, int) and x_d == e_d:
114+
continue
115+
116+
# The expand is changing this dimension (x_d is 1 or symbolic).
117+
# For the binary op to yield the same output, y must supply this dimension.
118+
# Get y dimension (virtually 1 if y has fewer dims than expand_shape).
119+
y_idx = y_rank - 1 - rev_i
120+
y_d = other_shape[y_idx] if y_idx >= 0 else 1
121+
122+
if isinstance(y_d, int) and y_d == e_d:
123+
continue # y covers the expansion at this dimension
92124

93-
if result_with_expand != result_without_expand:
94125
return check_result.fail(
95-
f"Removing Expand would change output shape from "
96-
f"{result_with_expand} to {result_without_expand}."
126+
f"Cannot verify that removing Expand is safe at dimension {i}: "
127+
f"x_d={x_d!r}, expand_d={e_d}, y_d={y_d!r}."
97128
)
98129

99130
return check_result

onnxscript/rewriter/rules/common/_remove_expand_before_binary_op_test.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,12 @@ def test_expand_target_shape_not_constant_not_removed(self):
207207
count = mod.expand_before_binary_op_rules.apply_to_model(model)
208208
self.assertEqual(count, 0)
209209

210-
def test_expand_unknown_input_shape_not_removed(self):
211-
"""Expand cannot be removed when the input shape is not statically known."""
212-
# No shape annotation on 'x'
210+
def test_expand_removed_with_symbolic_x_static_y(self):
211+
"""Expand with a symbolic x dim can be removed when y statically covers the expansion.
212+
213+
x=[N], expand_shape=[3, 4], y=[3, 4]: since y provides all expand dimensions
214+
as known integers, the expand is redundant regardless of N's runtime value.
215+
"""
213216
model_text = """
214217
<ir_version: 7, opset_import: [ "" : 17]>
215218
agraph (float[N] x, float[3, 4] y) => (float[3, 4] output)
@@ -221,6 +224,25 @@ def test_expand_unknown_input_shape_not_removed(self):
221224
"""
222225
model = ir.from_onnx_text(model_text)
223226
count = mod.expand_before_binary_op_rules.apply_to_model(model)
227+
self.assertEqual(count, 1)
228+
229+
def test_expand_with_symbolic_y_dim_not_removed(self):
230+
"""Expand cannot be removed when y has a symbolic dim in a position where the
231+
expand is doing work and that symbolic dim cannot be verified to equal expand_d.
232+
"""
233+
# x=[3], expand_shape=[4, 3], y=[M, 3].
234+
# At dim 0 (expand adds dim 4): x_d=1 (virtual), y_d=M (symbolic) -> can't verify.
235+
model_text = """
236+
<ir_version: 7, opset_import: [ "" : 17]>
237+
agraph (float[3] x, float[M, 3] y) => (float[4, 3] output)
238+
<int64[2] shape = {4, 3}>
239+
{
240+
expanded = Expand(x, shape)
241+
output = Add(expanded, y)
242+
}
243+
"""
244+
model = ir.from_onnx_text(model_text)
245+
count = mod.expand_before_binary_op_rules.apply_to_model(model)
224246
self.assertEqual(count, 0)
225247

226248

0 commit comments

Comments (0)