Skip to content

Commit 7ccdbfb

Browse files
Copilotjustinchuby
andcommitted
Create basic_rules.py and update rewriter to use it
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent 74bf271 commit 7ccdbfb

4 files changed

Lines changed: 827 additions & 289 deletions

File tree

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
import onnxscript.ir.passes.common as common_passes
1616
from onnxscript import ir
1717
from onnxscript.rewriter import (
18+
basic_rules,
1819
broadcast_to_matmul,
1920
cast_constant_of_shape,
2021
collapse_slices,
2122
gemm_to_matmul_add,
22-
llama_rule_sets,
2323
no_op,
2424
pattern,
2525
)
@@ -31,7 +31,7 @@
3131
gemm_to_matmul_add.rule, # type: ignore[has-type]
3232
*cast_constant_of_shape.rules.rules,
3333
*collapse_slices.rules.rules,
34-
*llama_rule_sets.llama_p0_rule_set().rules,
34+
*basic_rules.basic_optimization_rules().rules,
3535
)
3636

3737

onnxscript/rewriter/basic_rules.py

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Basic rewrite rules for general optimization patterns.
4+
5+
This module contains fundamental optimization rules that are generally applicable
6+
to most ONNX models, including cast elimination, transpose simplification,
7+
shape operation fusion, and other common patterns.
8+
"""
9+
from __future__ import annotations
10+
11+
from typing import ClassVar, Sequence
12+
13+
from onnxscript import ir
14+
from onnxscript.rewriter import _ir_utils as ir_utils
15+
from onnxscript.rewriter import pattern as orp
16+
17+
18+
class SqueezeReshape(orp.RewriteRuleClassBase):
19+
"""Replaces ``Reshape(Squeeze(x), [-1]])`` with ``Identity(x)`` for 1D x.
20+
21+
This pattern arises from the translation of pytorch symints.
22+
"""
23+
24+
def __init__(self):
25+
super().__init__("SqueezeReshape1d", remove_nodes=False)
26+
27+
def pattern(self, op, x):
28+
return op.Reshape(op.Squeeze(x), [-1])
29+
30+
def rewrite(self, op, x: ir.Value):
31+
return op.Identity(x)
32+
33+
def check(self, context, x) -> orp.MatchResult:
34+
del context # Unused
35+
check_result = orp.MatchResult()
36+
if not ir_utils.has_rank(x, 1):
37+
return check_result.fail("Input is not 1D")
38+
return check_result
39+
40+
41+
class CastIdentity(orp.RewriteRuleClassBase):
42+
"""Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
43+
44+
def pattern(self, op, x, to):
45+
return op.Cast(x, to=to)
46+
47+
def rewrite(self, op, x: ir.Value, to: ir.Attr):
48+
return op.Identity(x)
49+
50+
def check(self, context, x, to) -> orp.MatchResult:
51+
check_result = orp.MatchResult()
52+
if x.dtype != to.as_int():
53+
return check_result.fail("Input and output types are not the same")
54+
return check_result
55+
56+
57+
class CastCast(orp.RewriteRuleClassBase):
58+
"""Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""
59+
60+
# Simplify "cast type1 => type2 => type3" to "cast type1 => type3".
61+
# This rule is not valid for all combinations of types: e.g.,
62+
# it is not valid for float32 => float16 => float32 or float32 => int32 => string.
63+
# TODO: fill out the list of allowed combinations: the following is just a couple
64+
# that shows up in practice where it is valid
65+
_allowed_type2_type3: ClassVar = frozenset(
66+
{
67+
(ir.DataType.FLOAT, ir.DataType.FLOAT16),
68+
(ir.DataType.FLOAT, ir.DataType.BFLOAT16),
69+
}
70+
)
71+
72+
def pattern(self, op, x, to, to_ignored):
73+
return op.Cast(op.Cast(x, to=to_ignored), to=to)
74+
75+
def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
76+
check_result = orp.MatchResult()
77+
type2 = to_ignored.as_int()
78+
type3 = to.as_int()
79+
if (type2, type3) not in self._allowed_type2_type3:
80+
return check_result.fail(
81+
f"Intermediate cast elimination not recognized as valid from {type2} to {type3}. "
82+
f"Cast-Cast rule may be incomplete for this combination."
83+
)
84+
return check_result
85+
86+
def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
87+
return op.Cast(x, to=to)
88+
89+
90+
class ExpandIdentity(orp.RewriteRuleClassBase):
91+
"""Replaces ``Expand(..., shape)`` by ``Identity`` if possible."""
92+
93+
def pattern(self, op, x, shape):
94+
return op.Expand(x, shape)
95+
96+
def rewrite(self, op, x: ir.Value, shape: ir.Value):
97+
return op.Identity(x)
98+
99+
def check(self, context, x, shape) -> orp.MatchResult:
100+
check_result = orp.MatchResult()
101+
if shape.const_value is None:
102+
# Shape is not a constant and cannot be guessed.
103+
return check_result.fail("Shape is not a constant and cannot be guessed.")
104+
if (x_shape := x.shape) is None:
105+
# We don't know the shape of the input
106+
return check_result.fail("Input shape is not known.")
107+
if x_shape.dims != tuple(shape.const_value.numpy().tolist()):
108+
return check_result.fail(
109+
f"Input shape {x_shape.dims} does not match the shape {shape.const_value.numpy().tolist()}."
110+
)
111+
return check_result
112+
113+
114+
class ReshapeReshape(orp.RewriteRuleClassBase):
115+
"""Replaces ``Reshape(Reshape(X, ...), shape)`` by ``Reshape(X, shape)``.
116+
The pattern matches only if second reshape reshapes into a shape
117+
with positive values.
118+
"""
119+
120+
def pattern(self, op, x, shape_ignored, shape):
121+
return op.Reshape(op.Reshape(x, shape_ignored), shape)
122+
123+
def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
124+
return op.Reshape(x, shape)
125+
126+
def check(self, context, x, shape_ignored, shape) -> orp.MatchResult:
127+
check_result = orp.MatchResult()
128+
if shape_ignored.const_value is None:
129+
return check_result.fail("Shape ignored is not a constant.")
130+
if shape.const_value is None:
131+
return check_result.fail("Shape is not a constant.")
132+
if shape.const_value.numpy().min() <= 0:
133+
return check_result.fail("Shape has non-positive values.")
134+
return check_result
135+
136+
137+
class SlicesSplit(orp.RewriteRuleClassBase):
138+
"""Replaces ``Slice(x, ...), Slice(x, ...)``
139+
by ``Split(x, ...)`` if possible.
140+
"""
141+
142+
def pattern(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
143+
return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1)
144+
145+
def check(self, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.MatchResult:
146+
check_result = orp.MatchResult()
147+
if (
148+
axes0.const_value is None
149+
or axes1.const_value is None
150+
or axes0.const_value.numpy().tolist() != axes1.const_value.numpy().tolist()
151+
):
152+
return check_result.fail("Axes are not equal or not constant.")
153+
axes = axes0.const_value.numpy().tolist()
154+
if len(axes) != 1:
155+
return check_result.fail("Axes has more than one dimension.")
156+
if x.shape:
157+
rk = len(x.shape)
158+
else:
159+
rk = x.rank
160+
if axes[0] != -1 and axes[0] != rk - 1:
161+
return check_result.fail("Axes is not -1 or last dimension.")
162+
if (
163+
begin0.const_value is None
164+
or end0.const_value is None
165+
or begin1.const_value is None
166+
or end1.const_value is None
167+
):
168+
return check_result.fail("Begin or end are not constant values.")
169+
if begin0.const_value.numpy().tolist() != [0]:
170+
return check_result.fail("First begin value is not 0.")
171+
e0, b1, e1 = (
172+
end0.const_value.numpy().tolist(),
173+
begin1.const_value.numpy().tolist(),
174+
end1.const_value.numpy().tolist(),
175+
)
176+
if e0[0] != b1[0]:
177+
return check_result.fail("End0 is not equal to Begin1.")
178+
shape = x.shape
179+
if shape is None:
180+
return check_result.fail("Shape is not known.")
181+
last_dim = shape[-1]
182+
if not isinstance(last_dim, int):
183+
return check_result.fail("Last dimension is not known.")
184+
if last_dim != e1[0]:
185+
return check_result.fail("Last dimension is not equal to End1.")
186+
if last_dim // 2 != b1[0]:
187+
return check_result.fail("Last dimension is not equal to Begin1.")
188+
return check_result
189+
190+
def rewrite(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
191+
return op.Split(x, num_outputs=2, axis=-1, _outputs=2)
192+
193+
194+
class TransposeIdentity(orp.RewriteRuleClassBase):
195+
"""Replaces ``Transpose(. perm=perm)``
196+
when the permutation is identity.
197+
"""
198+
199+
def pattern(self, op, x, perm):
200+
return op.Transpose(x, perm=perm)
201+
202+
def check(self, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult:
203+
check_result = orp.MatchResult()
204+
if perm.is_ref():
205+
return check_result.fail("Permutation is a reference attribute.")
206+
if perm.type == ir.AttributeType.INTS:
207+
perm_ints = perm.as_ints()
208+
if perm_ints == list(range(len(perm_ints))):
209+
return check_result
210+
return check_result.fail("Permutation is not identity.")
211+
212+
def rewrite(self, op, x: ir.Value, perm: ir.Attr):
213+
return op.Identity(x)
214+
215+
216+
class TransposeTranspose(orp.RewriteRuleClassBase):
217+
"""Replaces ``Transpose(Transpose(., perm=perm1), perm=perm2)``
218+
when both permutations are inverse.
219+
"""
220+
221+
def pattern(self, op, x, perm1, perm2):
222+
return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)
223+
224+
def check(self, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult:
225+
check_result = orp.MatchResult()
226+
if perm1.is_ref() or perm2.is_ref():
227+
return check_result.fail("Permutation is a reference attribute.")
228+
return check_result
229+
230+
def _apply_transpose(self, perm: Sequence[int], on: list[int]) -> list[int]:
231+
assert len(perm) == len(on), "length mismatch"
232+
res = [-1 for i in on]
233+
for i, p in enumerate(perm):
234+
res[i] = on[p]
235+
return res
236+
237+
def _apply_transposes(
238+
self, perms: list[Sequence[int]], on: list[int] | None = None
239+
) -> list[int]:
240+
if on is None:
241+
on = list(range(len(perms[0])))
242+
for p in perms:
243+
on = self._apply_transpose(p, on)
244+
return on
245+
246+
def rewrite(self, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
247+
first = list(range(len(perm1.as_ints())))
248+
last = self._apply_transposes([perm1.as_ints(), perm2.as_ints()])
249+
if first == last:
250+
return op.Identity(x)
251+
return op.Transpose(x, perm=last)
252+
253+
254+
class UnsqueezeUnsqueeze(orp.RewriteRuleClassBase):
255+
"""Replaces ``Unsqueeze(Unsqueeze(., axes1), axes2)`` with one Unsqueeze."""
256+
257+
def pattern(self, op, x, axes1, axes2):
258+
return op.Unsqueeze(op.Unsqueeze(x, axes1), axes2)
259+
260+
def rewrite(self, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
261+
v1 = ir_utils.get_singleton_value(axes1)
262+
v2 = ir_utils.get_singleton_value(axes2)
263+
axes = [v1, v2] if v1 < v2 else [v2, v1 + 1]
264+
return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64)))
265+
266+
def check(self, context, x, axes1, axes2) -> orp.MatchResult:
267+
check_result = orp.MatchResult()
268+
del context # Unused
269+
del x # Unused
270+
# Currently restricted to single element positive axis
271+
v1 = ir_utils.get_singleton_value(axes1)
272+
v2 = ir_utils.get_singleton_value(axes2)
273+
if v1 is None or v2 is None:
274+
return check_result.fail("Axes are not constant.")
275+
if (v1 < 0) or (v2 < 0):
276+
return check_result.fail("Axes are negative.")
277+
return check_result
278+
279+
280+
# Create rule instances
281+
cast_cast_rule = CastCast.rule()
282+
cast_identity_rule = CastIdentity.rule()
283+
expand_identity_rule = ExpandIdentity.rule()
284+
reshape_reshape_rule = ReshapeReshape.rule()
285+
slice_split_rule = SlicesSplit.rule()
286+
transpose_identity_rule = TransposeIdentity.rule()
287+
transpose_transpose_rule = TransposeTranspose.rule()
288+
unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
289+
squeeze_reshape_1d_rule = SqueezeReshape.rule()
290+
291+
292+
def basic_optimization_rules() -> orp.RewriteRuleSet:
293+
"""Returns a set of basic optimization rules.
294+
295+
These rules perform fundamental optimizations such as:
296+
- Eliminating redundant cast operations
297+
- Simplifying consecutive operations of the same type
298+
- Removing identity operations
299+
- Optimizing shape manipulation operations
300+
301+
These rules are generally safe to apply as a first optimization pass
302+
before other more specialized optimizations.
303+
304+
Returns:
305+
RewriteRuleSet: A collection of basic optimization rules
306+
"""
307+
return orp.RewriteRuleSet(
308+
[
309+
cast_cast_rule,
310+
cast_identity_rule,
311+
expand_identity_rule,
312+
reshape_reshape_rule,
313+
slice_split_rule,
314+
transpose_identity_rule,
315+
transpose_transpose_rule,
316+
unsqueeze_unsqueeze_rule,
317+
squeeze_reshape_1d_rule,
318+
]
319+
)

0 commit comments

Comments
 (0)