1313
1414from __future__ import annotations
1515
16- import numpy as np
17-
1816from onnxscript import ir
1917from onnxscript .rewriter ._basics import MatchResult
2018from onnxscript .rewriter ._ir_utils import get_numpy_value
@@ -51,8 +49,19 @@ def _check_expand_removable(
5149) -> MatchResult :
5250 """Check if an Expand node can be safely removed before a binary op.
5351
54- The Expand is removable if the binary op's broadcasting produces the same
55- output shape when using the original (pre-expand) tensor directly.
52+ The Expand node ``expanded_x = Expand(x, expand_shape)`` before a binary op
53+ ``out = BinaryOp(expanded_x, y)`` can be removed when the binary op's
54+ own broadcasting produces the same output shape as the explicit expand.
55+
56+ The condition at each dimension ``i`` (right-aligned) is::
57+
58+ max(expand_shape[i], y[i]) == max(x[i], y[i])
59+
60+ which simplifies to: either ``x[i] == expand_shape[i]`` (expand is a no-op
61+ here) or ``y[i] == expand_shape[i]`` (y already covers the expansion).
62+
63+ This check works with dynamic (symbolic) dimensions in x or y as long as
64+ the expand target shape is a compile-time constant.
5665
5766 Args:
5867 expand_input: The value fed into the Expand node.
@@ -64,36 +73,58 @@ def _check_expand_removable(
6473 """
6574 check_result = MatchResult ()
6675
67- # Need static shape info for both inputs.
76+ # Need at least the rank of both inputs.
6877 expand_input_shape = expand_input .shape
6978 other_shape = other_input .shape
7079 if expand_input_shape is None or other_shape is None :
7180 return check_result .fail ("Input shapes are not statically known." )
7281
73- # Require fully static (integer-only) shapes to avoid symbolic dim issues.
74- if not expand_input_shape .is_static () or not other_shape .is_static ():
75- return check_result .fail ("Input shapes are not fully static." )
76-
7782 # The Expand target shape must be a compile-time constant.
7883 expand_shape_val = get_numpy_value (shape )
7984 if expand_shape_val is None :
8085 return check_result .fail ("Expand target shape is not a constant." )
8186
8287 expand_shape = tuple (int (v ) for v in expand_shape_val .tolist ())
83- x_shape = tuple (int (d ) for d in expand_input_shape )
84- y_shape = tuple (int (d ) for d in other_shape )
85-
86- # Verify that removing the Expand does not change the binary op's output shape.
87- try :
88- result_with_expand = np .broadcast_shapes (expand_shape , y_shape )
89- result_without_expand = np .broadcast_shapes (x_shape , y_shape )
90- except ValueError :
91- return check_result .fail ("Shapes are not broadcastable." )
88+ expand_rank = len (expand_shape )
89+ x_rank = expand_input_shape .rank ()
90+ y_rank = other_shape .rank ()
91+
92+ # Check each dimension of expand_shape (right-aligned).
93+ # For the expand to be removable at position i, we need:
94+ # max(e_d, y_d) == max(x_d, y_d)
95+ # which requires: e_d <= max(x_d, y_d).
96+ # Since a valid Expand can only broadcast from 1 (not shrink), if e_d > 1
97+ # then x_d is either 1 or e_d. The condition then reduces to:
98+ # x_d == e_d OR y_d == e_d.
99+ for rev_i in range (expand_rank ):
100+ i = expand_rank - 1 - rev_i
101+ e_d = expand_shape [i ] # always a known integer
102+
103+ # If expand target is 1 at this dim, Expand leaves x_d unchanged (broadcasting
104+ # never shrinks a dimension), so the output is max(x_d, y_d) in both cases.
105+ if e_d == 1 :
106+ continue
107+
108+ # Get x dimension (virtually 1 if x has fewer dims than expand_shape).
109+ x_idx = x_rank - 1 - rev_i
110+ x_d = expand_input_shape [x_idx ] if x_idx >= 0 else 1
111+
112+ # If x's dimension already equals the expand target, expand is a no-op here.
113+ if isinstance (x_d , int ) and x_d == e_d :
114+ continue
115+
116+ # The expand is changing this dimension (x_d is 1 or symbolic).
117+ # For the binary op to yield the same output, y must supply this dimension.
118+ # Get y dimension (virtually 1 if y has fewer dims than expand_shape).
119+ y_idx = y_rank - 1 - rev_i
120+ y_d = other_shape [y_idx ] if y_idx >= 0 else 1
121+
122+ if isinstance (y_d , int ) and y_d == e_d :
123+ continue # y covers the expansion at this dimension
92124
93- if result_with_expand != result_without_expand :
94125 return check_result .fail (
95- f"Removing Expand would change output shape from "
96- f"{ result_with_expand } to { result_without_expand } ."
126+ f"Cannot verify that removing Expand is safe at dimension { i } : "
127+ f"x_d= { x_d !r } , expand_d= { e_d } , y_d= { y_d !r } ."
97128 )
98129
99130 return check_result
0 commit comments