1+ # Copyright (c) Microsoft Corporation.
2+ # Licensed under the MIT License.
3+ """Basic rewrite rules for general optimization patterns.
4+
5+ This module contains fundamental optimization rules that are generally applicable
6+ to most ONNX models, including cast elimination, transpose simplification,
7+ shape operation fusion, and other common patterns.
8+ """
9+ from __future__ import annotations
10+
11+ from typing import ClassVar , Sequence
12+
13+ from onnxscript import ir
14+ from onnxscript .rewriter import _ir_utils as ir_utils
15+ from onnxscript .rewriter import pattern as orp
16+
17+
18+ class SqueezeReshape (orp .RewriteRuleClassBase ):
19+ """Replaces ``Reshape(Squeeze(x), [-1]])`` with ``Identity(x)`` for 1D x.
20+
21+ This pattern arises from the translation of pytorch symints.
22+ """
23+
24+ def __init__ (self ):
25+ super ().__init__ ("SqueezeReshape1d" , remove_nodes = False )
26+
27+ def pattern (self , op , x ):
28+ return op .Reshape (op .Squeeze (x ), [- 1 ])
29+
30+ def rewrite (self , op , x : ir .Value ):
31+ return op .Identity (x )
32+
33+ def check (self , context , x ) -> orp .MatchResult :
34+ del context # Unused
35+ check_result = orp .MatchResult ()
36+ if not ir_utils .has_rank (x , 1 ):
37+ return check_result .fail ("Input is not 1D" )
38+ return check_result
39+
40+
41+ class CastIdentity (orp .RewriteRuleClassBase ):
42+ """Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
43+
44+ def pattern (self , op , x , to ):
45+ return op .Cast (x , to = to )
46+
47+ def rewrite (self , op , x : ir .Value , to : ir .Attr ):
48+ return op .Identity (x )
49+
50+ def check (self , context , x , to ) -> orp .MatchResult :
51+ check_result = orp .MatchResult ()
52+ if x .dtype != to .as_int ():
53+ return check_result .fail ("Input and output types are not the same" )
54+ return check_result
55+
56+
57+ class CastCast (orp .RewriteRuleClassBase ):
58+ """Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""
59+
60+ # Simplify "cast type1 => type2 => type3" to "cast type1 => type3".
61+ # This rule is not valid for all combinations of types: e.g.,
62+ # it is not valid for float32 => float16 => float32 or float32 => int32 => string.
63+ # TODO: fill out the list of allowed combinations: the following is just a couple
64+ # that shows up in practice where it is valid
65+ _allowed_type2_type3 : ClassVar = frozenset (
66+ {
67+ (ir .DataType .FLOAT , ir .DataType .FLOAT16 ),
68+ (ir .DataType .FLOAT , ir .DataType .BFLOAT16 ),
69+ }
70+ )
71+
72+ def pattern (self , op , x , to , to_ignored ):
73+ return op .Cast (op .Cast (x , to = to_ignored ), to = to )
74+
75+ def check (self , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> orp .MatchResult :
76+ check_result = orp .MatchResult ()
77+ type2 = to_ignored .as_int ()
78+ type3 = to .as_int ()
79+ if (type2 , type3 ) not in self ._allowed_type2_type3 :
80+ return check_result .fail (
81+ f"Intermediate cast elimination not recognized as valid from { type2 } to { type3 } . "
82+ f"Cast-Cast rule may be incomplete for this combination."
83+ )
84+ return check_result
85+
86+ def rewrite (self , op , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ):
87+ return op .Cast (x , to = to )
88+
89+
90+ class ExpandIdentity (orp .RewriteRuleClassBase ):
91+ """Replaces ``Expand(..., shape)`` by ``Identity`` if possible."""
92+
93+ def pattern (self , op , x , shape ):
94+ return op .Expand (x , shape )
95+
96+ def rewrite (self , op , x : ir .Value , shape : ir .Value ):
97+ return op .Identity (x )
98+
99+ def check (self , context , x , shape ) -> orp .MatchResult :
100+ check_result = orp .MatchResult ()
101+ if shape .const_value is None :
102+ # Shape is not a constant and cannot be guessed.
103+ return check_result .fail ("Shape is not a constant and cannot be guessed." )
104+ if (x_shape := x .shape ) is None :
105+ # We don't know the shape of the input
106+ return check_result .fail ("Input shape is not known." )
107+ if x_shape .dims != tuple (shape .const_value .numpy ().tolist ()):
108+ return check_result .fail (
109+ f"Input shape { x_shape .dims } does not match the shape { shape .const_value .numpy ().tolist ()} ."
110+ )
111+ return check_result
112+
113+
114+ class ReshapeReshape (orp .RewriteRuleClassBase ):
115+ """Replaces ``Reshape(Reshape(X, ...), shape)`` by ``Reshape(X, shape)``.
116+ The pattern matches only if second reshape reshapes into a shape
117+ with positive values.
118+ """
119+
120+ def pattern (self , op , x , shape_ignored , shape ):
121+ return op .Reshape (op .Reshape (x , shape_ignored ), shape )
122+
123+ def rewrite (self , op , x : ir .Value , shape_ignored : ir .Value , shape : ir .Value ):
124+ return op .Reshape (x , shape )
125+
126+ def check (self , context , x , shape_ignored , shape ) -> orp .MatchResult :
127+ check_result = orp .MatchResult ()
128+ if shape_ignored .const_value is None :
129+ return check_result .fail ("Shape ignored is not a constant." )
130+ if shape .const_value is None :
131+ return check_result .fail ("Shape is not a constant." )
132+ if shape .const_value .numpy ().min () <= 0 :
133+ return check_result .fail ("Shape has non-positive values." )
134+ return check_result
135+
136+
137+ class SlicesSplit (orp .RewriteRuleClassBase ):
138+ """Replaces ``Slice(x, ...), Slice(x, ...)``
139+ by ``Split(x, ...)`` if possible.
140+ """
141+
142+ def pattern (self , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
143+ return op .Slice (x , begin0 , end0 , axes0 ), op .Slice (x , begin1 , end1 , axes1 )
144+
145+ def check (self , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> orp .MatchResult :
146+ check_result = orp .MatchResult ()
147+ if (
148+ axes0 .const_value is None
149+ or axes1 .const_value is None
150+ or axes0 .const_value .numpy ().tolist () != axes1 .const_value .numpy ().tolist ()
151+ ):
152+ return check_result .fail ("Axes are not equal or not constant." )
153+ axes = axes0 .const_value .numpy ().tolist ()
154+ if len (axes ) != 1 :
155+ return check_result .fail ("Axes has more than one dimension." )
156+ if x .shape :
157+ rk = len (x .shape )
158+ else :
159+ rk = x .rank
160+ if axes [0 ] != - 1 and axes [0 ] != rk - 1 :
161+ return check_result .fail ("Axes is not -1 or last dimension." )
162+ if (
163+ begin0 .const_value is None
164+ or end0 .const_value is None
165+ or begin1 .const_value is None
166+ or end1 .const_value is None
167+ ):
168+ return check_result .fail ("Begin or end are not constant values." )
169+ if begin0 .const_value .numpy ().tolist () != [0 ]:
170+ return check_result .fail ("First begin value is not 0." )
171+ e0 , b1 , e1 = (
172+ end0 .const_value .numpy ().tolist (),
173+ begin1 .const_value .numpy ().tolist (),
174+ end1 .const_value .numpy ().tolist (),
175+ )
176+ if e0 [0 ] != b1 [0 ]:
177+ return check_result .fail ("End0 is not equal to Begin1." )
178+ shape = x .shape
179+ if shape is None :
180+ return check_result .fail ("Shape is not known." )
181+ last_dim = shape [- 1 ]
182+ if not isinstance (last_dim , int ):
183+ return check_result .fail ("Last dimension is not known." )
184+ if last_dim != e1 [0 ]:
185+ return check_result .fail ("Last dimension is not equal to End1." )
186+ if last_dim // 2 != b1 [0 ]:
187+ return check_result .fail ("Last dimension is not equal to Begin1." )
188+ return check_result
189+
190+ def rewrite (self , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
191+ return op .Split (x , num_outputs = 2 , axis = - 1 , _outputs = 2 )
192+
193+
194+ class TransposeIdentity (orp .RewriteRuleClassBase ):
195+ """Replaces ``Transpose(. perm=perm)``
196+ when the permutation is identity.
197+ """
198+
199+ def pattern (self , op , x , perm ):
200+ return op .Transpose (x , perm = perm )
201+
202+ def check (self , context , x : ir .Value , perm : ir .Attr ) -> orp .MatchResult :
203+ check_result = orp .MatchResult ()
204+ if perm .is_ref ():
205+ return check_result .fail ("Permutation is a reference attribute." )
206+ if perm .type == ir .AttributeType .INTS :
207+ perm_ints = perm .as_ints ()
208+ if perm_ints == list (range (len (perm_ints ))):
209+ return check_result
210+ return check_result .fail ("Permutation is not identity." )
211+
212+ def rewrite (self , op , x : ir .Value , perm : ir .Attr ):
213+ return op .Identity (x )
214+
215+
216+ class TransposeTranspose (orp .RewriteRuleClassBase ):
217+ """Replaces ``Transpose(Transpose(., perm=perm1), perm=perm2)``
218+ when both permutations are inverse.
219+ """
220+
221+ def pattern (self , op , x , perm1 , perm2 ):
222+ return op .Transpose (op .Transpose (x , perm = perm1 ), perm = perm2 )
223+
224+ def check (self , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> orp .MatchResult :
225+ check_result = orp .MatchResult ()
226+ if perm1 .is_ref () or perm2 .is_ref ():
227+ return check_result .fail ("Permutation is a reference attribute." )
228+ return check_result
229+
230+ def _apply_transpose (self , perm : Sequence [int ], on : list [int ]) -> list [int ]:
231+ assert len (perm ) == len (on ), "length mismatch"
232+ res = [- 1 for i in on ]
233+ for i , p in enumerate (perm ):
234+ res [i ] = on [p ]
235+ return res
236+
237+ def _apply_transposes (
238+ self , perms : list [Sequence [int ]], on : list [int ] | None = None
239+ ) -> list [int ]:
240+ if on is None :
241+ on = list (range (len (perms [0 ])))
242+ for p in perms :
243+ on = self ._apply_transpose (p , on )
244+ return on
245+
246+ def rewrite (self , op , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ):
247+ first = list (range (len (perm1 .as_ints ())))
248+ last = self ._apply_transposes ([perm1 .as_ints (), perm2 .as_ints ()])
249+ if first == last :
250+ return op .Identity (x )
251+ return op .Transpose (x , perm = last )
252+
253+
254+ class UnsqueezeUnsqueeze (orp .RewriteRuleClassBase ):
255+ """Replaces ``Unsqueeze(Unsqueeze(., axes1), axes2)`` with one Unsqueeze."""
256+
257+ def pattern (self , op , x , axes1 , axes2 ):
258+ return op .Unsqueeze (op .Unsqueeze (x , axes1 ), axes2 )
259+
260+ def rewrite (self , op , x : ir .Value , axes1 : ir .Value , axes2 : ir .Value ):
261+ v1 = ir_utils .get_singleton_value (axes1 )
262+ v2 = ir_utils .get_singleton_value (axes2 )
263+ axes = [v1 , v2 ] if v1 < v2 else [v2 , v1 + 1 ]
264+ return op .Unsqueeze (x , op .Constant (value = ir .tensor (axes , dtype = ir .DataType .INT64 )))
265+
266+ def check (self , context , x , axes1 , axes2 ) -> orp .MatchResult :
267+ check_result = orp .MatchResult ()
268+ del context # Unused
269+ del x # Unused
270+ # Currently restricted to single element positive axis
271+ v1 = ir_utils .get_singleton_value (axes1 )
272+ v2 = ir_utils .get_singleton_value (axes2 )
273+ if v1 is None or v2 is None :
274+ return check_result .fail ("Axes are not constant." )
275+ if (v1 < 0 ) or (v2 < 0 ):
276+ return check_result .fail ("Axes are negative." )
277+ return check_result
278+
279+
280+ # Create rule instances
281+ cast_cast_rule = CastCast .rule ()
282+ cast_identity_rule = CastIdentity .rule ()
283+ expand_identity_rule = ExpandIdentity .rule ()
284+ reshape_reshape_rule = ReshapeReshape .rule ()
285+ slice_split_rule = SlicesSplit .rule ()
286+ transpose_identity_rule = TransposeIdentity .rule ()
287+ transpose_transpose_rule = TransposeTranspose .rule ()
288+ unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze .rule ()
289+ squeeze_reshape_1d_rule = SqueezeReshape .rule ()
290+
291+
292+ def basic_optimization_rules () -> orp .RewriteRuleSet :
293+ """Returns a set of basic optimization rules.
294+
295+ These rules perform fundamental optimizations such as:
296+ - Eliminating redundant cast operations
297+ - Simplifying consecutive operations of the same type
298+ - Removing identity operations
299+ - Optimizing shape manipulation operations
300+
301+ These rules are generally safe to apply as a first optimization pass
302+ before other more specialized optimizations.
303+
304+ Returns:
305+ RewriteRuleSet: A collection of basic optimization rules
306+ """
307+ return orp .RewriteRuleSet (
308+ [
309+ cast_cast_rule ,
310+ cast_identity_rule ,
311+ expand_identity_rule ,
312+ reshape_reshape_rule ,
313+ slice_split_rule ,
314+ transpose_identity_rule ,
315+ transpose_transpose_rule ,
316+ unsqueeze_unsqueeze_rule ,
317+ squeeze_reshape_1d_rule ,
318+ ]
319+ )
0 commit comments