Skip to content

Commit 5453385

Browse files
authored
Merge branch 'main' into justinchu/node-graph
2 parents 24b7ecd + 9d16b89 commit 5453385

7 files changed

Lines changed: 170 additions & 11 deletions

File tree

docs/tutorial/rewriter/rewrite_patterns.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,15 @@ In order to apply this method to the example above, first create the two separat
152152
:pyobject: erf_gelu_pattern_2
153153
```
154154

155+
:::{note}
156+
:name: rule-application-order-matters
157+
158+
When you pass multiple rules in `pattern_rewrite_rules`, the **order in which they appear is important**.
159+
This is because some rules may depend on patterns created or modified by earlier rules. For example, if `rule2` can only match after `rule1` has made a specific change in the model, then `rule1` must come **before** `rule2` in the list.
160+
If you're not seeing expected results, try adjusting the order or applying the rule set in a loop until no more changes occur.
161+
:::
162+
163+
155164
Then, create two separate `PatternRewriteRule`s, one for each target pattern. Pack these rules into a `RewriteRuleSet` object and apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter.
156165

157166
```{literalinclude} examples/erfgelu.py

onnxscript/ir/passes/_pass_infra.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,20 @@ def __call__(self, model: ir.Model) -> PassResult:
136136
f"The result of the pass '{self.__class__.__name__}' should be type PassResult. "
137137
"Please create one with ir.passes.PassResult()."
138138
)
139+
140+
# Checks that the declared in-place property is respected
141+
if self.in_place and result.model is not model:
142+
raise PassError(
143+
f"The pass '{self.__class__.__name__}' is declared in-place, "
144+
"but the model returned is *not* the same object as the input model. "
145+
"Pass developer: Pass should return the same model object or the in_place property should return False."
146+
)
147+
if not self.in_place and result.model is model:
148+
raise PassError(
149+
f"The pass '{self.__class__.__name__}' is declared not in-place, "
150+
"but the model returned *is* the same object as the input model. "
151+
"Pass developer: Pass should return a new model object or the in_place property should return True."
152+
)
139153
return result
140154

141155
@abc.abstractmethod

onnxscript/ir/passes/common/constant_manipulation.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@ class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
2323
Attributes:
2424
lift_all_constants: Whether to lift all Constant nodes, including those that does not contain a tensor attribute (e.g. with value_ints etc.)
2525
Default to False, where only Constants with the ``value`` attribute are lifted.
26+
size_limit: The minimum size of the tensor to be lifted. If the tensor contains
27+
number of elements less than size_limit, it will not be lifted. Default is 16.
2628
"""
2729

28-
def __init__(self, lift_all_constants: bool = False):
30+
def __init__(self, lift_all_constants: bool = False, size_limit: int = 16):
2931
super().__init__()
30-
self._lift_all_constants = lift_all_constants
32+
self.lift_all_constants = lift_all_constants
33+
self.size_limit = size_limit
3134

3235
def call(self, model: ir.Model) -> ir.passes.PassResult:
3336
count = 0
@@ -79,16 +82,17 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
7982

8083
def _constant_node_attribute_to_tensor(
8184
self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str
82-
) -> ir.Tensor | None:
85+
) -> ir.TensorProtocol | None:
8386
"""Convert constant node attribute to tensor."""
84-
if not self._lift_all_constants and attr_name != "value":
87+
if not self.lift_all_constants and attr_name != "value":
8588
logger.debug(
8689
"Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name
8790
)
8891
return None
8992

93+
tensor: ir.TensorProtocol
9094
if attr_name == "value":
91-
tensor = attr_value.as_tensor() # type: ignore[union-attr]
95+
tensor = attr_value.as_tensor()
9296
elif attr_name == "value_int":
9397
tensor = ir.tensor(
9498
attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name
@@ -110,5 +114,15 @@ def _constant_node_attribute_to_tensor(
110114
np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
111115
)
112116
else:
113-
tensor = None
114-
return tensor # type: ignore[return-value]
117+
raise ValueError(
118+
f"Unsupported constant node '{node.name}' attribute '{attr_name}'"
119+
)
120+
121+
if tensor.size < self.size_limit:
122+
logger.debug(
123+
"Tensor from node '%s' has less than %s elements",
124+
node.name,
125+
self.size_limit,
126+
)
127+
return None
128+
return tensor

onnxscript/ir/passes/common/constant_manipulation_test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(
5656

5757
# Perform lift constants to initializers
5858
result = constant_manipulation.LiftConstantsToInitializersPass(
59-
lift_all_constants=lift_all_constants
59+
lift_all_constants=lift_all_constants, size_limit=0
6060
)(model)
6161
self.assertTrue(result.modified)
6262
# Check that the constant node is lifted to an initializer
@@ -130,7 +130,7 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(
130130
ir_version=10,
131131
)
132132
result = constant_manipulation.LiftConstantsToInitializersPass(
133-
lift_all_constants=lift_all_constants
133+
lift_all_constants=lift_all_constants, size_limit=0
134134
)(model)
135135
self.assertTrue(result.modified)
136136
# Check that the constant node is lifted to the subgraph initializers
@@ -206,7 +206,7 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
206206

207207
# Perform lift constants to initializers
208208
result = constant_manipulation.LiftConstantsToInitializersPass(
209-
lift_all_constants=lift_all_constants
209+
lift_all_constants=lift_all_constants, size_limit=0
210210
)(model)
211211
if lift_all_constants:
212212
self.assertTrue(result.modified)
@@ -249,3 +249,7 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self):
249249
self.assertFalse(result.modified)
250250
# Check that the constant node is not lifted to an initializer
251251
self.assertEqual(len(result.model.graph.initializers), 0)
252+
253+
254+
if __name__ == "__main__":
255+
unittest.main()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Pass for topologically sorting the graphs."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"TopologicalSortPass",
9+
]
10+
11+
12+
from onnxscript import ir
13+
14+
15+
class TopologicalSortPass(ir.passes.InPlacePass):
16+
"""Topologically sort graphs and functions in a model."""
17+
18+
def call(self, model: ir.Model) -> ir.passes.PassResult:
19+
original_nodes = list(model.graph)
20+
model.graph.sort()
21+
sorted_nodes = list(model.graph)
22+
for function in model.functions.values():
23+
original_nodes.extend(function)
24+
function.sort()
25+
sorted_nodes.extend(function)
26+
27+
# Compare node orders to determine if any changes were made
28+
modified = False
29+
for node, new_node in zip(original_nodes, sorted_nodes):
30+
if node is not new_node:
31+
modified = True
32+
break
33+
return ir.passes.PassResult(model=model, modified=modified)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Unit tests for the TopologicalSortPass."""
4+
5+
import unittest
6+
7+
from onnxscript import ir
8+
from onnxscript.ir.passes.common import topological_sort
9+
10+
11+
class TopologicalSortPassTest(unittest.TestCase):
12+
def setUp(self):
13+
self.node_a = ir.node("A", inputs=[], name="node_a")
14+
self.node_b = ir.node("B", inputs=self.node_a.outputs, name="node_b")
15+
self.node_c = ir.node("C", inputs=self.node_b.outputs, name="node_c")
16+
17+
def test_topological_sort_modified_true(self):
18+
graph = ir.Graph(
19+
inputs=self.node_a.inputs,
20+
outputs=self.node_c.outputs,
21+
nodes=[self.node_c, self.node_b, self.node_a], # Unsorted nodes
22+
name="test_graph",
23+
)
24+
model = ir.Model(graph, ir_version=10)
25+
result = topological_sort.TopologicalSortPass()(model)
26+
self.assertTrue(result.modified)
27+
self.assertEqual(
28+
tuple(result.model.graph),
29+
(self.node_a, self.node_b, self.node_c),
30+
)
31+
32+
def test_topological_sort_modified_false(self):
33+
"""Test that modified is False when the input model is already sorted."""
34+
sorted_graph = ir.Graph(
35+
inputs=self.node_a.inputs,
36+
outputs=self.node_c.outputs,
37+
nodes=[self.node_a, self.node_b, self.node_c], # Sorted nodes
38+
name="test_graph",
39+
)
40+
sorted_model = ir.Model(sorted_graph, ir_version=10)
41+
result = topological_sort.TopologicalSortPass()(sorted_model)
42+
self.assertFalse(result.modified)
43+
self.assertEqual(
44+
tuple(result.model.graph),
45+
(self.node_a, self.node_b, self.node_c),
46+
)
47+
48+
def test_topological_sort_on_functions(self):
49+
"""Test that TopologicalSortPass works on functions in a model."""
50+
# Create a function with unsorted nodes
51+
func_graph = ir.Graph(
52+
inputs=self.node_a.inputs,
53+
outputs=self.node_c.outputs,
54+
nodes=[self.node_c, self.node_b, self.node_a], # Unsorted nodes
55+
)
56+
function = ir.Function(
57+
domain="test_domain",
58+
name="test_function",
59+
graph=func_graph,
60+
attributes=[],
61+
)
62+
63+
# Create a model with the function
64+
graph = ir.Graph(
65+
inputs=[],
66+
outputs=[],
67+
nodes=[],
68+
name="test_graph",
69+
)
70+
model = ir.Model(graph, ir_version=10, functions=[function])
71+
72+
# Apply the TopologicalSortPass
73+
result = topological_sort.TopologicalSortPass()(model)
74+
75+
# Verify that the nodes in the function are sorted
76+
sorted_func_nodes = (self.node_a, self.node_b, self.node_c)
77+
self.assertTrue(result.modified)
78+
self.assertEqual(
79+
tuple(result.model.functions[function.identifier()]),
80+
sorted_func_nodes,
81+
)
82+
83+
84+
if __name__ == "__main__":
85+
unittest.main()
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/onnxruntime/overview
22
--index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/
3-
onnxruntime==1.22.0.dev20250303002
3+
onnxruntime==1.22.0.dev20250402004

0 commit comments

Comments
 (0)