Skip to content

Commit feb20f1

Browse files
authored
[IR] Allow pass result as pass input (#2220)
Allow pass result as pass input so users can chain calls to multiple passes more easily Before: ```py result = pass1(model) result = pass(result.model) ``` Now it is also possible to do: ```py result = pass1(model) result = pass(result) ```
1 parent b0a4401 commit feb20f1

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

onnxscript/ir/passes/_pass_infra.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,11 @@ def destructive(self) -> bool:
108108
"""
109109
return not self.in_place and self.changes_input
110110

111-
def __call__(self, model: ir.Model) -> PassResult:
111+
def __call__(self, model_or_result: ir.Model | PassResult, /) -> PassResult:
112+
if isinstance(model_or_result, PassResult):
113+
model = model_or_result.model
114+
else:
115+
model = model_or_result
112116
# Check preconditions
113117
try:
114118
self.requires(model)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from __future__ import annotations
5+
6+
import unittest
7+
8+
from onnxscript import ir
9+
from onnxscript.ir.passes import _pass_infra
10+
11+
12+
class PassBaseTest(unittest.TestCase):
13+
def test_pass_results_can_be_used_as_pass_input(self):
14+
class TestPass(_pass_infra.PassBase):
15+
@property
16+
def in_place(self) -> bool:
17+
return True
18+
19+
@property
20+
def changes_input(self) -> bool:
21+
return False
22+
23+
def call(self, model: ir.Model) -> _pass_infra.PassResult:
24+
# This is a no-op pass
25+
return _pass_infra.PassResult(model=model, modified=False)
26+
27+
pass_ = TestPass()
28+
model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
29+
result = pass_(model)
30+
self.assertIsInstance(result, _pass_infra.PassResult)
31+
# pass can take the result of another pass as input
32+
result_1 = pass_(result)
33+
# It can also take the model as input
34+
result_2 = pass_(result.model)
35+
self.assertIs(result_1.model, result_2.model)
36+
37+
38+
if __name__ == "__main__":
39+
unittest.main()

0 commit comments

Comments
 (0)