Skip to content

Commit 1048faf

Browse files
authored
[passes] Move inliner to common passes (#2206)
Expose inliner to common passes for general usage. Fix #2194
1 parent 397baa1 commit 1048faf

5 files changed

Lines changed: 38 additions & 31 deletions

File tree

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44

55
from __future__ import annotations
66

7+
import dataclasses
8+
9+
__all__ = ["InlinePass", "InlinePassResult"]
10+
711
from collections import defaultdict
812
from typing import Iterable, List, Sequence, Tuple
913

10-
import onnxscript.ir as ir
11-
import onnxscript.ir.convenience as ir_convenience
14+
import onnxscript.ir.convenience as _ir_convenience
15+
from onnxscript import ir
1216

1317
# A replacement for a node specifies a list of nodes that replaces the original node,
1418
# and a list of values that replaces the original node's outputs.
@@ -22,7 +26,7 @@
2226
CallStack = List[CallSiteId]
2327

2428

25-
def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str:
29+
def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str: # pylint: disable=unused-argument
2630
"""Generate a unique name from a name, calling-context, and set of used names.
2731
2832
If there is a name clash, we add a numeric suffix to the name to make
@@ -188,6 +192,11 @@ def id_abbreviation(id: ir.OperatorIdentifier) -> str:
188192
return {id: id_abbreviation(id) for id in function_ids}
189193

190194

195+
@dataclasses.dataclass
196+
class InlinePassResult(ir.passes.PassResult):
197+
id_count: dict[ir.OperatorIdentifier, int]
198+
199+
191200
class InlinePass(ir.passes.InPlacePass):
192201
def __init__(self) -> None:
193202
super().__init__()
@@ -206,11 +215,11 @@ def _reset(self, model: ir.Model) -> None:
206215
self.used_node_names = set()
207216
self.node_context = {}
208217

209-
def call(self, model: ir.Model) -> ir.passes.PassResult:
218+
def call(self, model: ir.Model) -> InlinePassResult:
210219
self._reset(model)
211-
modified = self.inline_calls_in(model.graph)
220+
id_count = self._inline_calls_in(model.graph)
212221
model.functions.clear()
213-
return ir.passes.PassResult(model, modified)
222+
return InlinePassResult(model, modified=bool(id_count), id_count=id_count)
214223

215224
def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
216225
id = node.op_identifier()
@@ -235,7 +244,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl
235244
if default_attr_values:
236245
attributes = {**attributes, **default_attr_values}
237246
if any(
238-
attr.type == ir.AttributeType.GRAPH or attr.type == ir.AttributeType.GRAPHS
247+
attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
239248
for attr in attributes.values()
240249
):
241250
raise ValueError(
@@ -264,7 +273,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl
264273
output_values = [value_map[output] for output in function.outputs]
265274
return nodes, output_values # type: ignore
266275

267-
def inline_calls_in(self, graph: ir.Graph) -> bool:
276+
def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]:
268277
for input in graph.inputs:
269278
if input.name is not None:
270279
self.used_value_names.add(input.name)
@@ -300,7 +309,7 @@ def inline_calls_in(self, graph: ir.Graph) -> bool:
300309
self._function_id_abbreviations[id] + call_site_prefix
301310
)
302311
nodes, values = self._instantiate_call(node, call_site)
303-
ir_convenience.replace_nodes_and_values(
312+
_ir_convenience.replace_nodes_and_values(
304313
graph,
305314
insertion_point=node,
306315
old_nodes=[node],
@@ -313,14 +322,8 @@ def inline_calls_in(self, graph: ir.Graph) -> bool:
313322
if not isinstance(attr, ir.Attr):
314323
continue
315324
if attr.type == ir.AttributeType.GRAPH:
316-
self.inline_calls_in(attr.as_graph())
325+
self._inline_calls_in(attr.as_graph())
317326
elif attr.type == ir.AttributeType.GRAPHS:
318-
for graph in attr.as_graphs():
319-
self.inline_calls_in(graph)
320-
return bool(id_count)
321-
322-
323-
def inline(model: ir.Model) -> None:
324-
"""Inline all function calls (recursively) in the model."""
325-
if model.functions:
326-
InlinePass()(model)
327+
for g in attr.as_graphs():
328+
self._inline_calls_in(g)
329+
return id_count

onnxscript/optimizer/_inliner_test.py renamed to onnxscript/ir/passes/common/inliner_test.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
"""Tests for onnxscript.optimizer._inliner"""
3+
"""Tests for the inliner pass."""
44

55
from __future__ import annotations
66

@@ -11,7 +11,7 @@
1111
from onnx import parser
1212

1313
from onnxscript import ir
14-
from onnxscript.optimizer._inliner import inline
14+
from onnxscript.ir.passes.common import inliner
1515

1616

1717
def _name_checker(renameable: Sequence[str] | None) -> Callable[[str, str], bool]:
@@ -46,7 +46,7 @@ def _check(
4646
name_check = _name_checker(renameable)
4747
model_proto = parser.parse_model(input_model)
4848
model_ir = ir.serde.deserialize_model(model_proto)
49-
inline(model_ir)
49+
inliner.InlinePass()(model_ir)
5050
proto = ir.serde.serialize_model(model_ir)
5151
text = onnx.printer.to_text(proto)
5252
print(text)
@@ -68,10 +68,7 @@ def _check(
6868
self.assertTrue(isinstance(value, ir.Attr))
6969
self.assertTrue(isinstance(expected_value, ir.Attr))
7070
self.assertEqual(value.type, expected_value.type)
71-
if (
72-
value.type != ir.AttributeType.GRAPH
73-
and value.type != ir.AttributeType.GRAPHS
74-
):
71+
if value.type not in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
7572
self.assertEqual(value.value, expected_value.value)
7673
else:
7774
self.fail("Graph attributes are not supported yet")

onnxscript/optimizer/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414

1515
import onnx
1616

17+
import onnxscript.ir.passes.common.inliner
1718
import onnxscript.ir.passes.common.unused_removal
1819
import onnxscript.optimizer._constant_folding as constant_folding
1920
import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
2021
import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
2122
from onnxscript import ir
22-
from onnxscript.optimizer._inliner import inline
2323
from onnxscript.optimizer._optimizer import optimize_ir
2424

2525
basic_constant_propagation = constant_folding.basic_constant_propagation
@@ -35,6 +35,12 @@ def optimize(model: ir.Model, *args, **kwargs) -> ir.Model:
3535
return legacy_optimizer.optimize(model, *args, **kwargs)
3636

3737

38+
def inline(model: ir.Model) -> None:
39+
"""Inline all function calls (recursively) in the model."""
40+
if model.functions:
41+
onnxscript.ir.passes.common.inliner.InlinePass()(model)
42+
43+
3844
def fold_constants(
3945
model: ir.Model | onnx.ModelProto, *args, **kwargs
4046
) -> constant_folding.FoldConstantsResult | bool:

onnxscript/optimizer/_optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import logging
66

77
import onnxscript.ir.passes.common.constant_manipulation
8+
import onnxscript.ir.passes.common.inliner
89
import onnxscript.ir.passes.common.unused_removal
910
from onnxscript import ir, rewriter
10-
from onnxscript.optimizer import _constant_folding, _inliner
11+
from onnxscript.optimizer import _constant_folding
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -35,7 +36,7 @@ def optimize_ir(
3536
outer optimization loop if no change is detected in one iteration.
3637
"""
3738
optimizer_pass = ir.passes.Sequential(
38-
_inliner.InlinePass(),
39+
onnxscript.ir.passes.common.inliner.InlinePass(),
3940
ir.passes.PassManager(
4041
[
4142
_constant_folding.FoldConstantsPass(

onnxscript/version_converter/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
"convert_version",
88
]
99

10+
import onnxscript.optimizer
1011
from onnxscript import ir
11-
from onnxscript.optimizer import _inliner
1212
from onnxscript.version_converter import _version_converter
1313

1414

@@ -17,5 +17,5 @@ def convert_version(model: ir.Model, target_version: int) -> None:
1717

1818
# In functions, we can have attribute-parameters, which means we don't know the value of the attribute.
1919
# Hence, we inline all the functions.
20-
_inliner.inline(model)
20+
onnxscript.optimizer.inline(model)
2121
_version_converter.convert_version(model, target_version)

0 commit comments

Comments
 (0)