22# Licensed under the MIT License.
33from __future__ import annotations
44
5- from typing import ClassVar
5+ from typing import ClassVar , Sequence
66
77from onnxscript import ir
88from onnxscript .rewriter import _ir_utils as ir_utils
@@ -32,26 +32,23 @@ def check(self, context, x) -> orp.MatchResult:
3232 return check_result
3333
3434
35- class CastIdentity (orp .RewriteRuleAsClass ):
35+ class CastIdentity (orp .RewriteRuleClassBase ):
3636 """Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
3737
38- @classmethod
39- def pattern (cls , op , x , to ):
38+ def pattern (self , op , x , to ):
4039 return op .Cast (x , to = to )
4140
42- @classmethod
43- def rewrite (cls , op , x : ir .Value , to : ir .Attr ):
41+ def rewrite (self , op , x : ir .Value , to : ir .Attr ):
4442 return op .Identity (x )
4543
46- @classmethod
47- def check (cls , context , x , to ) -> orp .MatchResult :
44+ def check (self , context , x , to ) -> orp .MatchResult :
4845 check_result = orp .MatchResult ()
49- if x .dtype != to .value :
46+ if x .dtype != to .as_int () :
5047 return check_result .fail ("Input and output types are not the same" )
5148 return check_result
5249
5350
54- class CastCast (orp .RewriteRuleAsClass ):
51+ class CastCast (orp .RewriteRuleClassBase ):
5552 """Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""
5653
5754 _allowed_tensor_types : ClassVar = {
@@ -61,37 +58,31 @@ class CastCast(orp.RewriteRuleAsClass):
6158 ir .DataType .DOUBLE ,
6259 }
6360
64- @classmethod
65- def pattern (cls , op , x , to , to_ignored ):
61+ def pattern (self , op , x , to , to_ignored ):
6662 return op .Cast (op .Cast (x , to = to_ignored ), to = to )
6763
68- @classmethod
69- def check (cls , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> orp .MatchResult :
64+ def check (self , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> orp .MatchResult :
7065 check_result = orp .MatchResult ()
71- if to .value not in cls ._allowed_tensor_types :
72- return check_result .fail (f"Output type { to .value } is not allowed" )
73- if to_ignored .as_int () not in cls ._allowed_tensor_types :
74- return check_result .fail (f"Ignored type { to_ignored .value } is not allowed" )
66+ if to .as_int () not in self ._allowed_tensor_types :
67+ return check_result .fail (f"Output type { to .as_int () } is not allowed" )
68+ if to_ignored .as_int () not in self ._allowed_tensor_types :
69+ return check_result .fail (f"Ignored type { to_ignored .as_int () } is not allowed" )
7570 return check_result
7671
77- @classmethod
78- def rewrite (cls , op , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ):
72+ def rewrite (self , op , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ):
7973 return op .Cast (x , to = to )
8074
8175
82- class ExpandIdentity (orp .RewriteRuleAsClass ):
76+ class ExpandIdentity (orp .RewriteRuleClassBase ):
8377 """Replaces ``Expand(..., shape)`` by ``Identity`` if possible."""
8478
85- @classmethod
86- def pattern (cls , op , x , shape ):
79+ def pattern (self , op , x , shape ):
8780 return op .Expand (x , shape )
8881
89- @classmethod
90- def rewrite (cls , op , x : ir .Value , shape : ir .Value ):
82+ def rewrite (self , op , x : ir .Value , shape : ir .Value ):
9183 return op .Identity (x )
9284
93- @classmethod
94- def check (cls , context , x , shape ) -> orp .MatchResult :
85+ def check (self , context , x , shape ) -> orp .MatchResult :
9586 check_result = orp .MatchResult ()
9687 if shape .const_value is None :
9788 # Shape is not a constant and cannot be guessed.
@@ -106,22 +97,19 @@ def check(cls, context, x, shape) -> orp.MatchResult:
10697 return check_result
10798
10899
109- class ReshapeReshape (orp .RewriteRuleAsClass ):
100+ class ReshapeReshape (orp .RewriteRuleClassBase ):
110101 """Replaces ``Reshape(Reshape(X, ...), shape)`` by ``Reshape(X, shape)``.
111102 The pattern matches only if second reshape reshapes into a shape
112103 with positive values.
113104 """
114105
115- @classmethod
116- def pattern (cls , op , x , shape_ignored , shape ):
106+ def pattern (self , op , x , shape_ignored , shape ):
117107 return op .Reshape (op .Reshape (x , shape_ignored ), shape )
118108
119- @classmethod
120- def rewrite (cls , op , x : ir .Value , shape_ignored : ir .Value , shape : ir .Value ):
109+ def rewrite (self , op , x : ir .Value , shape_ignored : ir .Value , shape : ir .Value ):
121110 return op .Reshape (x , shape )
122111
123- @classmethod
124- def check (cls , context , x , shape_ignored , shape ) -> orp .MatchResult :
112+ def check (self , context , x , shape_ignored , shape ) -> orp .MatchResult :
125113 check_result = orp .MatchResult ()
126114 if shape_ignored .const_value is None :
127115 return check_result .fail ("Shape ignored is not a constant." )
@@ -132,17 +120,15 @@ def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult:
132120 return check_result
133121
134122
135- class SlicesSplit (orp .RewriteRuleAsClass ):
123+ class SlicesSplit (orp .RewriteRuleClassBase ):
136124 """Replaces ``Slice(x, ...), Slice(x, ...)``
137125 by ``Split(x, ...)`` if possible.
138126 """
139127
140- @classmethod
141- def pattern (cls , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
128+ def pattern (self , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
142129 return op .Slice (x , begin0 , end0 , axes0 ), op .Slice (x , begin1 , end1 , axes1 )
143130
144- @classmethod
145- def check (cls , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> orp .MatchResult :
131+ def check (self , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> orp .MatchResult :
146132 check_result = orp .MatchResult ()
147133 if (
148134 axes0 .const_value is None
@@ -187,94 +173,83 @@ def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.Matc
187173 return check_result .fail ("Last dimension is not equal to Begin1." )
188174 return check_result
189175
190- @classmethod
191- def rewrite (cls , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
176+ def rewrite (self , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
192177 return op .Split (x , num_outputs = 2 , axis = - 1 , _outputs = 2 )
193178
194179
195- class TransposeIdentity (orp .RewriteRuleAsClass ):
180+ class TransposeIdentity (orp .RewriteRuleClassBase ):
196181 """Replaces ``Transpose(. perm=perm)``
197182 when the permutation is identity.
198183 """
199184
200- @classmethod
201- def pattern (cls , op , x , perm ):
185+ def pattern (self , op , x , perm ):
202186 return op .Transpose (x , perm = perm )
203187
204- @classmethod
205- def check (cls , context , x : ir .Value , perm : ir .Attr ) -> orp .MatchResult :
188+ def check (self , context , x : ir .Value , perm : ir .Attr ) -> orp .MatchResult :
206189 check_result = orp .MatchResult ()
207190 if isinstance (perm , ir .RefAttr ):
208191 return check_result .fail ("Permutation is a reference attribute." )
209192 if perm .type == ir .AttributeType .INTS :
210- if perm .value == list (range (len (perm .value ))):
193+ perm_ints = perm .as_ints ()
194+ if perm_ints == list (range (len (perm_ints ))):
211195 return check_result
212196 return check_result .fail ("Permutation is not identity." )
213197
214- @classmethod
215- def rewrite (cls , op , x : ir .Value , perm : ir .Attr ):
198+ def rewrite (self , op , x : ir .Value , perm : ir .Attr ):
216199 return op .Identity (x )
217200
218201
219- class TransposeTranspose (orp .RewriteRuleAsClass ):
202+ class TransposeTranspose (orp .RewriteRuleClassBase ):
220203 """Replaces ``Transpose(Transpose(., perm=perm1), perm=perm2)``
221204 when both permutations are inverse.
222205 """
223206
224- @classmethod
225- def pattern (cls , op , x , perm1 , perm2 ):
207+ def pattern (self , op , x , perm1 , perm2 ):
226208 return op .Transpose (op .Transpose (x , perm = perm1 ), perm = perm2 )
227209
228- @classmethod
229- def check (cls , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> orp .MatchResult :
210+ def check (self , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> orp .MatchResult :
230211 check_result = orp .MatchResult ()
231212 if isinstance (perm1 , ir .RefAttr ) or isinstance (perm2 , ir .RefAttr ):
232213 return check_result .fail ("Permutation is a reference attribute." )
233214 return check_result
234215
235- @classmethod
236- def _apply_transpose (cls , perm : tuple [int , ...], on : list [int ]) -> list [int ]:
216+ def _apply_transpose (self , perm : Sequence [int ], on : list [int ]) -> list [int ]:
237217 assert len (perm ) == len (on ), "length mismatch"
238218 res = [- 1 for i in on ]
239219 for i , p in enumerate (perm ):
240220 res [i ] = on [p ]
241221 return res
242222
243- @classmethod
244223 def _apply_transposes (
245- cls , perms : list [tuple [int , ... ]], on : list [int ] | None = None
224+ self , perms : list [Sequence [int ]], on : list [int ] | None = None
246225 ) -> list [int ]:
247226 if on is None :
248227 on = list (range (len (perms [0 ])))
249228 for p in perms :
250- on = cls ._apply_transpose (p , on )
229+ on = self ._apply_transpose (p , on )
251230 return on
252231
253- @classmethod
254- def rewrite (cls , op , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ):
255- first = list (range (len (perm1 .value )))
256- last = cls ._apply_transposes ([perm1 .value , perm2 .value ])
232+ def rewrite (self , op , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ):
233+ first = list (range (len (perm1 .as_ints ())))
234+ last = self ._apply_transposes ([perm1 .as_ints (), perm2 .as_ints ()])
257235 if first == last :
258236 return op .Identity (x )
259237 return op .Transpose (x , perm = last )
260238
261239
262- class UnsqueezeUnsqueeze (orp .RewriteRuleAsClass ):
240+ class UnsqueezeUnsqueeze (orp .RewriteRuleClassBase ):
263241 """Replaces ``Unsqueeze(Unsqueeze(., axes1), axes2)`` with one Unsqueeze."""
264242
265- @classmethod
266- def pattern (cls , op , x , axes1 , axes2 ):
243+ def pattern (self , op , x , axes1 , axes2 ):
267244 return op .Unsqueeze (op .Unsqueeze (x , axes1 ), axes2 )
268245
269- @classmethod
270- def rewrite (cls , op , x : ir .Value , axes1 : ir .Value , axes2 : ir .Value ):
246+ def rewrite (self , op , x : ir .Value , axes1 : ir .Value , axes2 : ir .Value ):
271247 v1 = ir_utils .get_singleton_value (axes1 )
272248 v2 = ir_utils .get_singleton_value (axes2 )
273249 axes = [v1 , v2 ] if v1 < v2 else [v2 , v1 + 1 ]
274250 return op .Unsqueeze (x , op .Constant (value = ir .tensor (axes , dtype = ir .DataType .INT64 )))
275251
276- @classmethod
277- def check (cls , context , x , axes1 , axes2 ) -> orp .MatchResult :
252+ def check (self , context , x , axes1 , axes2 ) -> orp .MatchResult :
278253 check_result = orp .MatchResult ()
279254 del context # Unused
280255 del x # Unused
@@ -288,14 +263,14 @@ def check(cls, context, x, axes1, axes2) -> orp.MatchResult:
288263 return check_result
289264
290265
291- cast_cast_rule = orp . make_rewrite_rule_from_class ( CastCast )
292- cast_identity_rule = orp . make_rewrite_rule_from_class ( CastIdentity )
293- expand_identity_rule = orp . make_rewrite_rule_from_class ( ExpandIdentity )
294- reshape_reshape_rule = orp . make_rewrite_rule_from_class ( ReshapeReshape )
295- slice_split_rule = orp . make_rewrite_rule_from_class ( SlicesSplit , True )
296- transpose_identity_rule = orp . make_rewrite_rule_from_class ( TransposeIdentity )
297- transpose_transpose_rule = orp . make_rewrite_rule_from_class ( TransposeTranspose )
298- unsqueeze_unsqueeze_rule = orp . make_rewrite_rule_from_class ( UnsqueezeUnsqueeze )
266+ cast_cast_rule = CastCast . rule ( )
267+ cast_identity_rule = CastIdentity . rule ( )
268+ expand_identity_rule = ExpandIdentity . rule ( )
269+ reshape_reshape_rule = ReshapeReshape . rule ( )
270+ slice_split_rule = SlicesSplit . rule ( )
271+ transpose_identity_rule = TransposeIdentity . rule ( )
272+ transpose_transpose_rule = TransposeTranspose . rule ( )
273+ unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze . rule ( )
299274squeeze_reshape_1d_rule = SqueezeReshape .rule ()
300275
301276
0 commit comments