Skip to content

Commit 2e89a2c

Browse files
authored
[pass] Create topological sort pass (#2191)
Simply expose the `sort()` api as a pass for composability.
1 parent 312219b commit 2e89a2c

2 files changed

Lines changed: 83 additions & 0 deletions

File tree

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Pass for topologically sorting the graphs."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"TopologicalSortPass",
9+
]
10+
11+
12+
from onnxscript import ir
13+
14+
15+
class TopologicalSortPass(ir.passes.InPlacePass):
16+
"""Topologically sort graphs and functions in a model."""
17+
18+
def call(self, model: ir.Model) -> ir.passes.PassResult:
19+
original_nodes = list(model.graph)
20+
model.graph.sort()
21+
sorted_nodes = list(model.graph)
22+
for function in model.functions.values():
23+
original_nodes.extend(function)
24+
function.sort()
25+
sorted_nodes.extend(function)
26+
27+
# Compare node orders to determine if any changes were made
28+
modified = False
29+
for node, new_node in zip(original_nodes, sorted_nodes):
30+
if node is not new_node:
31+
modified = True
32+
break
33+
return ir.passes.PassResult(model=model, modified=modified)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Unit tests for the TopologicalSortPass."""
4+
5+
import unittest
6+
7+
from onnxscript import ir
8+
from onnxscript.ir.passes.common import topological_sort
9+
10+
11+
class TopologicalSortPassTest(unittest.TestCase):
12+
def setUp(self):
13+
self.node_a = ir.node("A", inputs=[], name="node_a")
14+
self.node_b = ir.node("B", inputs=self.node_a.outputs, name="node_b")
15+
self.node_c = ir.node("C", inputs=self.node_b.outputs, name="node_c")
16+
17+
def test_topological_sort_modified_true(self):
18+
graph = ir.Graph(
19+
inputs=self.node_a.inputs,
20+
outputs=self.node_c.outputs,
21+
nodes=[self.node_c, self.node_b, self.node_a], # Unsorted nodes
22+
name="test_graph",
23+
)
24+
model = ir.Model(graph, ir_version=10)
25+
result = topological_sort.TopologicalSortPass()(model)
26+
self.assertTrue(result.modified)
27+
self.assertEqual(
28+
tuple(result.model.graph),
29+
(self.node_a, self.node_b, self.node_c),
30+
)
31+
32+
def test_topological_sort_modified_false(self):
33+
"""Test that modified is False when the input model is already sorted."""
34+
sorted_graph = ir.Graph(
35+
inputs=self.node_a.inputs,
36+
outputs=self.node_c.outputs,
37+
nodes=[self.node_a, self.node_b, self.node_c], # Sorted nodes
38+
name="test_graph",
39+
)
40+
sorted_model = ir.Model(sorted_graph, ir_version=10)
41+
result = topological_sort.TopologicalSortPass()(sorted_model)
42+
self.assertFalse(result.modified)
43+
self.assertEqual(
44+
tuple(result.model.graph),
45+
(self.node_a, self.node_b, self.node_c),
46+
)
47+
48+
49+
if __name__ == "__main__":
50+
unittest.main()

0 commit comments

Comments
 (0)