Skip to content

Commit 8437137

Browse files
Copilot authored and xadupre committed
Remove producer-specific shape tracing; use shape annotations and broadcast comparison
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/4d4f4fb8-b66e-456e-a1d6-b1eb5ca1b532
1 parent ed31552 commit 8437137

2 files changed

Lines changed: 171 additions & 167 deletions

File tree

onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py

Lines changed: 138 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -42,178 +42,170 @@
4242
)
4343

4444

45-
def _get_shape_tensor_length(shape_value: ir.Value) -> int | None:
46-
"""Try to determine the number of elements in a 1-D shape tensor.
45+
def _compute_broadcast_dim(d1, d2):
46+
"""Return the numpy broadcast of two dimension values.
4747
48-
Returns the length as an int, or ``None`` if it cannot be determined.
48+
Each dimension value may be an ``int`` or an ``onnx_ir.SymbolicDim``.
49+
Returns ``None`` when the result cannot be determined statically (e.g. two
50+
distinct symbolic values neither of which is known to be 1).
4951
"""
50-
const = get_numpy_value(shape_value)
51-
if const is not None:
52-
return len(const)
53-
54-
# Use the tensor's own shape annotation (should be 1-D).
55-
tensor_shape = shape_value.shape
56-
if tensor_shape is not None and tensor_shape.rank() == 1:
57-
dim = tensor_shape[0]
58-
if isinstance(dim, int):
59-
return dim
60-
61-
# Trace through Concat and Shape nodes.
62-
producer = shape_value.producer()
63-
if producer is None:
64-
return None
52+
if d1 == 1:
53+
return d2
54+
if d2 == 1:
55+
return d1
56+
if d1 == d2:
57+
return d1
58+
return None
6559

66-
if producer.op_type == "Concat":
67-
total = 0
68-
for inp in producer.inputs:
69-
if inp is None:
70-
return None
71-
seg_len = _get_shape_tensor_length(inp)
72-
if seg_len is None:
73-
return None
74-
total += seg_len
75-
return total
76-
77-
if producer.op_type == "Shape":
78-
x_input = producer.inputs[0] if producer.inputs else None
79-
if x_input is None:
80-
return None
81-
start_attr = producer.attributes.get("start")
82-
end_attr = producer.attributes.get("end")
83-
start = start_attr.value if start_attr is not None else 0
84-
if end_attr is not None:
85-
return end_attr.value - start
86-
# end defaults to rank of x
87-
if x_input.shape is not None:
88-
x_rank = x_input.shape.rank()
89-
if x_rank is not None:
90-
return x_rank - start
def _compute_broadcast_shape(shape1: ir.Shape, shape2: ir.Shape) -> list | None:
    """Symbolically apply numpy-style broadcasting to two shapes.

    Returns the broadcast shape as a list of dimension values (``int`` or
    ``SymbolicDim``), or ``None`` when the result cannot be determined
    (unknown rank, or a pair of incompatible dimensions).
    """
    rank1 = shape1.rank()
    rank2 = shape2.rank()
    if rank1 is None or rank2 is None:
        return None
    out_rank = max(rank1, rank2)
    broadcast = []
    for pos in range(out_rank):
        # Right-align both shapes; missing leading dims behave like 1.
        i1 = pos - (out_rank - rank1)
        i2 = pos - (out_rank - rank2)
        dim1 = shape1[i1] if i1 >= 0 else 1
        dim2 = shape2[i2] if i2 >= 0 else 1
        merged = _compute_broadcast_dim(dim1, dim2)
        if merged is None:
            return None
        broadcast.append(merged)
    return broadcast
9284

93-
return None
9485

def _check_dims_sufficient(
    expand_shape: ir.Shape,
    x_shape: ir.Shape,
    y_shape: ir.Shape,
) -> MatchResult:
    """Check that x and y together cover every dimension of the expand target.

    For each dimension ``i`` of *expand_shape* (right-aligned) the expand is
    considered redundant when at least one of the following holds:

    - ``expand_shape[i] == 1`` - expand cannot shrink a dim, so ``x_d`` must
      also be 1 and both with and without expand produce ``y_d``.
    - ``x_d == expand_shape[i]`` - the expand is a no-op at this dim.
    - ``y_d == expand_shape[i]`` - ``y`` already supplies this expansion.

    Comparisons work for both ``int`` and ``SymbolicDim`` values.

    Args:
        expand_shape: The Expand target shape (e.g. from its output annotation).
        x_shape: Shape of the value fed into the Expand.
        y_shape: Shape of the binary op's other operand.

    Returns:
        A successful ``MatchResult`` when every dimension is covered, a
        failed one otherwise.
    """
    check_result = MatchResult()
    e_rank = expand_shape.rank()
    x_rank = x_shape.rank()
    y_rank = y_shape.rank()
    if e_rank is None:
        return check_result.fail("Expand output rank is unknown.")
    # x_rank/y_rank feed the index arithmetic below; fail the match cleanly
    # instead of raising a TypeError when either rank is unknown.
    if x_rank is None or y_rank is None:
        return check_result.fail("Input ranks are unknown.")

    for rev_i in range(e_rank):
        i = e_rank - 1 - rev_i
        e_d = expand_shape[i]

        if isinstance(e_d, int) and e_d == 1:
            continue  # expand cannot shrink; x_d is also 1, no-op

        x_idx = x_rank - 1 - rev_i
        x_d = x_shape[x_idx] if x_idx >= 0 else 1
        if x_d == e_d:
            continue  # expand is a no-op at this dimension

        y_idx = y_rank - 1 - rev_i
        y_d = y_shape[y_idx] if y_idx >= 0 else 1
        if y_d == e_d:
            continue  # y already supplies this dimension

        return check_result.fail(
            f"Cannot verify that removing Expand is safe at dimension {i}: "
            f"x_d={x_d!r}, expand_d={e_d!r}, y_d={y_d!r}."
        )

    return check_result
145133

146134

147135
def _check_expand_removable(
148136
expand_input: ir.Value,
149137
shape: ir.Value,
150138
other_input: ir.Value,
139+
expand_output: ir.Value | None = None,
140+
binary_op_output: ir.Value | None = None,
151141
) -> MatchResult:
152142
"""Check if an Expand node can be safely removed before a binary op.
153143
154-
The Expand node ``expanded_x = Expand(x, expand_shape)`` before a binary op
155-
``out = BinaryOp(expanded_x, y)`` can be removed when the binary op's
156-
own broadcasting produces the same output shape as the explicit expand.
144+
The Expand ``expanded_x = Expand(x, expand_shape)`` before a binary op
145+
``out = BinaryOp(expanded_x, y)`` is redundant when the binary op's own
146+
broadcasting produces the same output as if the expand had been applied.
157147
158-
Two strategies are tried in order:
148+
Three strategies are tried in order:
159149
160-
1. **Constant expand shape**: When the expand target shape is a compile-time
161-
constant, each dimension is checked individually (right-aligned). At
162-
dimension ``i`` the expand is safe to remove if any of the following hold:
150+
1. **Constant expand shape** - When ``shape`` is a compile-time constant,
151+
the dimension values are extracted from it and the check is performed
152+
directly.
163153
164-
- ``expand_shape[i] == 1`` - expand can never shrink a dim, so x_d is
165-
also 1 and both paths produce ``y_d``.
166-
- ``x_d == expand_shape[i]`` - expand is a no-op here.
167-
- ``y_d == expand_shape[i]`` - y already covers the expansion.
154+
2. **Expand output shape annotation** - When ``shape`` is dynamic but the
155+
Expand node's output value already carries a shape annotation (e.g.
156+
after ONNX shape inference has been applied to the model), those
157+
dimension values are used for the check.
168158
169-
2. **Dynamic expand shape**: When the target shape is not a compile-time
170-
constant, the rule traces through ``Shape`` and ``Concat`` nodes to
171-
extract individual dimension values from the shape tensor. The same
172-
dimension-by-dimension safety check is then applied. This handles
173-
patterns such as ``Expand(x, Concat(Shape(x, 0, 1), Shape(x, 1, 2)))``
174-
where the expand is provably a no-op.
159+
3. **Binary op output shape** - When neither of the above is available,
160+
the rule verifies that ``broadcast(x.shape, y.shape)`` symbolically
161+
equals the binary op's output shape. If they agree, the binary op's
162+
own broadcasting already accounts for all the expansion and the
163+
Expand is redundant.
175164
176165
Args:
177166
expand_input: The value fed into the Expand node (``x``).
178167
shape: The target shape operand of the Expand node.
179168
other_input: The other operand of the binary op (``y``).
169+
expand_output: The output value of the Expand node. Required for
170+
strategy 2.
171+
binary_op_output: The output value of the binary op. Required for
172+
strategy 3.
180173
181174
Returns:
182-
A MatchResult that is successful when the Expand can be removed.
175+
A :class:`MatchResult` that is successful when the Expand can be
176+
removed.
183177
"""
184178
check_result = MatchResult()
185179

186-
expand_input_shape = expand_input.shape
187-
other_shape = other_input.shape
188-
if expand_input_shape is None or other_shape is None:
180+
x_shape = expand_input.shape
181+
y_shape = other_input.shape
182+
if x_shape is None or y_shape is None:
189183
return check_result.fail("Input shapes are not known.")
190184

191-
x_rank = expand_input_shape.rank()
192-
y_rank = other_shape.rank()
185+
x_rank = x_shape.rank()
186+
y_rank = y_shape.rank()
193187

194-
# --- Path 1: expand target shape is a compile-time constant ---
188+
# --- Strategy 1: expand target shape is a compile-time constant ---
195189
expand_shape_val = get_numpy_value(shape)
196190
if expand_shape_val is not None:
197191
expand_shape = tuple(int(v) for v in expand_shape_val.tolist())
198192
expand_rank = len(expand_shape)
199193

200194
for rev_i in range(expand_rank):
201195
i = expand_rank - 1 - rev_i
202-
e_d = expand_shape[i] # always a known integer
196+
e_d = expand_shape[i] # always a known integer from numpy
203197

204-
# expand cannot shrink a dim, so x_d must also be 1 here;
205-
# both with and without expand the output is y_d.
206198
if e_d == 1:
207-
continue
199+
continue # expand cannot shrink; x_d is also 1, no-op
208200

209201
x_idx = x_rank - 1 - rev_i
210-
x_d = expand_input_shape[x_idx] if x_idx >= 0 else 1
202+
x_d = x_shape[x_idx] if x_idx >= 0 else 1
211203

212204
if isinstance(x_d, int) and x_d == e_d:
213205
continue # expand is a no-op at this dimension
214206

215207
y_idx = y_rank - 1 - rev_i
216-
y_d = other_shape[y_idx] if y_idx >= 0 else 1
208+
y_d = y_shape[y_idx] if y_idx >= 0 else 1
217209

218210
if isinstance(y_d, int) and y_d == e_d:
219211
continue # y already supplies this dimension
@@ -225,44 +217,28 @@ def _check_expand_removable(
225217

226218
return check_result
227219

228-
# --- Path 2: expand target shape is dynamic ---
229-
# Trace through Shape/Concat nodes to extract individual elements of the
230-
# shape tensor, then apply the same dimension-by-dimension check.
231-
expand_rank = _get_shape_tensor_length(shape)
232-
if expand_rank is None:
220+
# --- Strategy 2: Expand output shape is known (e.g. from shape inference) ---
221+
if expand_output is not None and expand_output.shape is not None:
222+
return _check_dims_sufficient(expand_output.shape, x_shape, y_shape)
223+
224+
# --- Strategy 3: use the binary op's output shape ---
225+
# broadcast(x.shape, y.shape) must equal the binary op's output shape.
226+
# If it does, the binary op's own broadcasting already produces the same
227+
# result as first expanding x and then broadcasting.
228+
if binary_op_output is not None and binary_op_output.shape is not None:
229+
op_output_shape = binary_op_output.shape
230+
if op_output_shape.rank() is not None:
231+
computed = _compute_broadcast_shape(x_shape, y_shape)
232+
if computed is not None and len(computed) == op_output_shape.rank():
233+
if all(c == a for c, a in zip(computed, op_output_shape)):
234+
return check_result
233235
return check_result.fail(
234-
"Expand target shape is dynamic and its length cannot be determined."
236+
"broadcast(x.shape, y.shape) does not match the binary op output shape."
235237
)
236238

237-
for i in range(expand_rank):
238-
e_d = _get_dim_from_shape_value(shape, i)
239-
if e_d is None:
240-
return check_result.fail(
241-
f"Cannot determine expand shape at dimension {i}."
242-
)
243-
244-
if isinstance(e_d, int) and e_d == 1:
245-
continue # expand is a no-op at this dimension
246-
247-
x_idx = x_rank - expand_rank + i
248-
x_d = expand_input_shape[x_idx] if x_idx >= 0 else 1
249-
250-
# e_d == x_d works for both int and SymbolicDim (same symbolic name).
251-
if x_d == e_d:
252-
continue # expand is a no-op at this dimension
253-
254-
y_idx = y_rank - expand_rank + i
255-
y_d = other_shape[y_idx] if y_idx >= 0 else 1
256-
257-
if y_d == e_d:
258-
continue # y already supplies this dimension
259-
260-
return check_result.fail(
261-
f"Cannot verify that removing Expand is safe at dimension {i}: "
262-
f"x_d={x_d!r}, expand_d={e_d!r}, y_d={y_d!r}."
263-
)
264-
265-
return check_result
239+
return check_result.fail(
240+
"Expand target shape is not a constant and no shape annotations are available."
241+
)
266242

267243

268244
class _ExpandFirstInput(RewriteRuleClassBase):
@@ -276,8 +252,11 @@ def pattern(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
276252
return getattr(op, self._op_type)(op.Expand(x, shape), y)
277253

278254
def check(self, context, x: ir.Value, shape: ir.Value, y: ir.Value) -> MatchResult:
    """Decide whether the Expand feeding the first binary-op operand is removable."""
    root_node = context.root
    # inputs[0] of the matched binary op is the Expand node's output value.
    expanded_value = root_node.inputs[0] if root_node.inputs else None
    op_result = root_node.outputs[0] if root_node.outputs else None
    return _check_expand_removable(
        x,
        shape,
        y,
        expand_output=expanded_value,
        binary_op_output=op_result,
    )
281260

282261
def rewrite(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
    """Re-emit the binary op directly on x and y, dropping the Expand."""
    binary_op = getattr(op, self._op_type)
    return binary_op(x, y)
@@ -294,8 +273,11 @@ def pattern(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
294273
return getattr(op, self._op_type)(x, op.Expand(y, shape))
295274

296275
def check(self, context, x: ir.Value, y: ir.Value, shape: ir.Value) -> MatchResult:
    """Decide whether the Expand feeding the second binary-op operand is removable."""
    root_node = context.root
    # inputs[1] of the matched binary op is the Expand node's output value.
    expanded_value = root_node.inputs[1] if root_node.inputs else None
    op_result = root_node.outputs[0] if root_node.outputs else None
    return _check_expand_removable(
        y,
        shape,
        x,
        expand_output=expanded_value,
        binary_op_output=op_result,
    )
299281

300282
def rewrite(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
301283
return getattr(op, self._op_type)(x, y)

0 commit comments

Comments (0)