4242)
4343
4444
45- def _get_shape_tensor_length ( shape_value : ir . Value ) -> int | None :
46- """Try to determine the number of elements in a 1-D shape tensor .
45+ def _compute_broadcast_dim ( d1 , d2 ) :
46+ """Return the numpy broadcast of two dimension values .
4747
48- Returns the length as an int, or ``None`` if it cannot be determined.
48+ Each dimension value may be an ``int`` or an ``onnx_ir.SymbolicDim``.
49+ Returns ``None`` when the result cannot be determined statically (e.g. two
50+ distinct symbolic values neither of which is known to be 1).
4951 """
50- const = get_numpy_value (shape_value )
51- if const is not None :
52- return len (const )
53-
54- # Use the tensor's own shape annotation (should be 1-D).
55- tensor_shape = shape_value .shape
56- if tensor_shape is not None and tensor_shape .rank () == 1 :
57- dim = tensor_shape [0 ]
58- if isinstance (dim , int ):
59- return dim
60-
61- # Trace through Concat and Shape nodes.
62- producer = shape_value .producer ()
63- if producer is None :
64- return None
52+ if d1 == 1 :
53+ return d2
54+ if d2 == 1 :
55+ return d1
56+ if d1 == d2 :
57+ return d1
58+ return None
def _compute_broadcast_shape(shape1: ir.Shape, shape2: ir.Shape) -> list | None:
    """Compute the numpy-style broadcast of two shapes symbolically.

    Returns the broadcast shape as a list of dimension values (``int`` or
    ``SymbolicDim``), or ``None`` when the result cannot be determined
    (unknown rank, or incompatible static dimensions).
    """
    r1 = shape1.rank()
    r2 = shape2.rank()
    if r1 is None or r2 is None:
        return None
    out_rank = max(r1, r2)
    dims = []
    for pos in range(out_rank):
        # Right-align both shapes; positions hanging off the left are
        # implicitly 1, as in numpy broadcasting.
        i1 = pos - out_rank + r1
        i2 = pos - out_rank + r2
        lhs = shape1[i1] if i1 >= 0 else 1
        rhs = shape2[i2] if i2 >= 0 else 1
        merged = _compute_broadcast_dim(lhs, rhs)
        if merged is None:
            return None
        dims.append(merged)
    return dims
def _check_dims_sufficient(
    expand_shape: ir.Shape,
    x_shape: ir.Shape,
    y_shape: ir.Shape,
) -> MatchResult:
    """Check that x and y together cover every dimension of the expand target.

    For each dimension ``i`` of *expand_shape* (right-aligned) the expand is
    considered redundant when at least one of the following holds:

    - ``expand_shape[i] == 1`` - expand cannot shrink a dim, so ``x_d`` must
      also be 1 and both with and without expand produce ``y_d``.
    - ``x_d == expand_shape[i]`` - the expand is a no-op at this dim.
    - ``y_d == expand_shape[i]`` - ``y`` already supplies this expansion.

    Comparisons work for both ``int`` and ``SymbolicDim`` values.

    Returns:
        A successful MatchResult when every dimension passes the check, a
        failed one (with a reason) otherwise.
    """
    check_result = MatchResult()
    e_rank = expand_shape.rank()
    x_rank = x_shape.rank()
    y_rank = y_shape.rank()
    if e_rank is None:
        return check_result.fail("Expand output rank is unknown.")
    # Fix: x_rank/y_rank are used in index arithmetic below; the surrounding
    # code (e.g. _compute_broadcast_shape) treats rank() as possibly None, and
    # without this guard an unknown rank would raise TypeError instead of
    # failing the match.
    if x_rank is None or y_rank is None:
        return check_result.fail("Input ranks are unknown.")

    for rev_i in range(e_rank):
        i = e_rank - 1 - rev_i
        e_d = expand_shape[i]

        if isinstance(e_d, int) and e_d == 1:
            continue  # expand cannot shrink; x_d is also 1, no-op

        x_idx = x_rank - 1 - rev_i
        x_d = x_shape[x_idx] if x_idx >= 0 else 1
        if x_d == e_d:
            continue  # expand is a no-op at this dimension

        y_idx = y_rank - 1 - rev_i
        y_d = y_shape[y_idx] if y_idx >= 0 else 1
        if y_d == e_d:
            continue  # y already supplies this dimension

        return check_result.fail(
            f"Cannot verify that removing Expand is safe at dimension {i}: "
            f"x_d={x_d!r}, expand_d={e_d!r}, y_d={y_d!r}."
        )

    return check_result
146134
def _check_expand_removable(
    expand_input: ir.Value,
    shape: ir.Value,
    other_input: ir.Value,
    expand_output: ir.Value | None = None,
    binary_op_output: ir.Value | None = None,
) -> MatchResult:
    """Check if an Expand node can be safely removed before a binary op.

    The Expand ``expanded_x = Expand(x, expand_shape)`` before a binary op
    ``out = BinaryOp(expanded_x, y)`` is redundant when the binary op's own
    broadcasting produces the same output as if the expand had been applied.

    Three strategies are tried in order:

    1. **Constant expand shape** - When ``shape`` is a compile-time constant,
       the dimension values are extracted from it and the check is performed
       directly.

    2. **Expand output shape annotation** - When ``shape`` is dynamic but the
       Expand node's output value already carries a shape annotation (e.g.
       after ONNX shape inference has been applied to the model), those
       dimension values are used for the check.

    3. **Binary op output shape** - The rule verifies that
       ``broadcast(x.shape, y.shape)`` symbolically equals the binary op's
       output shape. If they agree, the binary op's own broadcasting already
       accounts for all the expansion and the Expand is redundant.

    Args:
        expand_input: The value fed into the Expand node (``x``).
        shape: The target shape operand of the Expand node.
        other_input: The other operand of the binary op (``y``).
        expand_output: The output value of the Expand node. Required for
            strategy 2.
        binary_op_output: The output value of the binary op. Required for
            strategy 3.

    Returns:
        A :class:`MatchResult` that is successful when the Expand can be
        removed.
    """
    check_result = MatchResult()

    x_shape = expand_input.shape
    y_shape = other_input.shape
    if x_shape is None or y_shape is None:
        return check_result.fail("Input shapes are not known.")

    x_rank = x_shape.rank()
    y_rank = y_shape.rank()

    # --- Strategy 1: expand target shape is a compile-time constant ---
    expand_shape_val = get_numpy_value(shape)
    if expand_shape_val is not None:
        # Fix: guard against unknown ranks before the index arithmetic below,
        # mirroring the guard pattern used by _compute_broadcast_shape.
        if x_rank is None or y_rank is None:
            return check_result.fail("Input ranks are unknown.")

        expand_shape = tuple(int(v) for v in expand_shape_val.tolist())
        expand_rank = len(expand_shape)

        for rev_i in range(expand_rank):
            i = expand_rank - 1 - rev_i
            e_d = expand_shape[i]  # always a known integer from numpy

            if e_d == 1:
                continue  # expand cannot shrink; x_d is also 1, no-op

            x_idx = x_rank - 1 - rev_i
            x_d = x_shape[x_idx] if x_idx >= 0 else 1
            if isinstance(x_d, int) and x_d == e_d:
                continue  # expand is a no-op at this dimension

            y_idx = y_rank - 1 - rev_i
            y_d = y_shape[y_idx] if y_idx >= 0 else 1
            if isinstance(y_d, int) and y_d == e_d:
                continue  # y already supplies this dimension

            return check_result.fail(
                f"Cannot verify that removing Expand is safe at dimension {i}: "
                f"x_d={x_d!r}, expand_d={e_d!r}, y_d={y_d!r}."
            )

        return check_result

    # --- Strategy 2: Expand output shape is known (e.g. from shape inference) ---
    if expand_output is not None and expand_output.shape is not None:
        dims_result = _check_dims_sufficient(expand_output.shape, x_shape, y_shape)
        # Improvement: a strategy-2 failure no longer short-circuits the whole
        # check; strategy 3 may still be able to prove the expand redundant.
        # NOTE(review): assumes MatchResult is truthy on success — confirm.
        if dims_result:
            return dims_result
        if binary_op_output is None or binary_op_output.shape is None:
            return dims_result

    # --- Strategy 3: use the binary op's output shape ---
    # broadcast(x.shape, y.shape) must equal the binary op's output shape.
    # If it does, the binary op's own broadcasting already produces the same
    # result as first expanding x and then broadcasting.
    if binary_op_output is not None and binary_op_output.shape is not None:
        op_output_shape = binary_op_output.shape
        if op_output_shape.rank() is not None:
            computed = _compute_broadcast_shape(x_shape, y_shape)
            if computed is not None and len(computed) == op_output_shape.rank():
                if all(c == a for c, a in zip(computed, op_output_shape)):
                    return check_result
        return check_result.fail(
            "broadcast(x.shape, y.shape) does not match the binary op output shape."
        )

    return check_result.fail(
        "Expand target shape is not a constant and no shape annotations are available."
    )
267243
268244class _ExpandFirstInput (RewriteRuleClassBase ):
@@ -276,8 +252,11 @@ def pattern(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
276252 return getattr (op , self ._op_type )(op .Expand (x , shape ), y )
277253
278254 def check (self , context , x : ir .Value , shape : ir .Value , y : ir .Value ) -> MatchResult :
279- del context # Unused
280- return _check_expand_removable (x , shape , y )
255+ expand_output = context .root .inputs [0 ] if context .root .inputs else None
256+ binary_op_output = context .root .outputs [0 ] if context .root .outputs else None
257+ return _check_expand_removable (
258+ x , shape , y , expand_output = expand_output , binary_op_output = binary_op_output
259+ )
281260
282261 def rewrite (self , op , x : ir .Value , shape : ir .Value , y : ir .Value ) -> ir .Value :
283262 return getattr (op , self ._op_type )(x , y )
@@ -294,8 +273,11 @@ def pattern(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
294273 return getattr (op , self ._op_type )(x , op .Expand (y , shape ))
295274
296275 def check (self , context , x : ir .Value , y : ir .Value , shape : ir .Value ) -> MatchResult :
297- del context # Unused
298- return _check_expand_removable (y , shape , x )
276+ expand_output = context .root .inputs [1 ] if context .root .inputs else None
277+ binary_op_output = context .root .outputs [0 ] if context .root .outputs else None
278+ return _check_expand_removable (
279+ y , shape , x , expand_output = expand_output , binary_op_output = binary_op_output
280+ )
299281
300282 def rewrite (self , op , x : ir .Value , y : ir .Value , shape : ir .Value ) -> ir .Value :
301283 return getattr (op , self ._op_type )(x , y )
0 commit comments