44
55from __future__ import annotations
66
7+ import dataclasses
8+
9+ __all__ = ["InlinePass" , "InlinePassResult" ]
10+
711from collections import defaultdict
812from typing import Iterable , List , Sequence , Tuple
913
10- import onnxscript .ir as ir
11- import onnxscript . ir . convenience as ir_convenience
14+ import onnxscript .ir . convenience as _ir_convenience
15+ from onnxscript import ir
1216
1317# A replacement for a node specifies a list of nodes that replaces the original node,
1418# and a list of values that replaces the original node's outputs.
2226CallStack = List [CallSiteId ]
2327
2428
25- def _make_unique_name (name : str , callstack : CallStack , used_names : set [str ]) -> str :
29+ def _make_unique_name (name : str , callstack : CallStack , used_names : set [str ]) -> str : # pylint: disable=unused-argument
2630 """Generate a unique name from a name, calling-context, and set of used names.
2731
2832 If there is a name clash, we add a numeric suffix to the name to make
@@ -188,6 +192,11 @@ def id_abbreviation(id: ir.OperatorIdentifier) -> str:
188192 return {id : id_abbreviation (id ) for id in function_ids }
189193
190194
195+ @dataclasses .dataclass
196+ class InlinePassResult (ir .passes .PassResult ):
197+ id_count : dict [ir .OperatorIdentifier , int ]
198+
199+
191200class InlinePass (ir .passes .InPlacePass ):
192201 def __init__ (self ) -> None :
193202 super ().__init__ ()
@@ -206,11 +215,11 @@ def _reset(self, model: ir.Model) -> None:
206215 self .used_node_names = set ()
207216 self .node_context = {}
208217
209- def call (self , model : ir .Model ) -> ir . passes . PassResult :
218+ def call (self , model : ir .Model ) -> InlinePassResult :
210219 self ._reset (model )
211- modified = self .inline_calls_in (model .graph )
220+ id_count = self ._inline_calls_in (model .graph )
212221 model .functions .clear ()
213- return ir . passes . PassResult (model , modified )
222+ return InlinePassResult (model , modified = bool ( id_count ), id_count = id_count )
214223
215224 def _instantiate_call (self , node : ir .Node , call_site_id : CallSiteId ) -> NodeReplacement :
216225 id = node .op_identifier ()
@@ -235,7 +244,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl
235244 if default_attr_values :
236245 attributes = {** attributes , ** default_attr_values }
237246 if any (
238- attr .type == ir .AttributeType .GRAPH or attr . type == ir .AttributeType .GRAPHS
247+ attr .type in { ir .AttributeType .GRAPH , ir .AttributeType .GRAPHS }
239248 for attr in attributes .values ()
240249 ):
241250 raise ValueError (
@@ -264,7 +273,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl
264273 output_values = [value_map [output ] for output in function .outputs ]
265274 return nodes , output_values # type: ignore
266275
267- def inline_calls_in (self , graph : ir .Graph ) -> bool :
276+ def _inline_calls_in (self , graph : ir .Graph ) -> dict [ ir . OperatorIdentifier , int ] :
268277 for input in graph .inputs :
269278 if input .name is not None :
270279 self .used_value_names .add (input .name )
@@ -300,7 +309,7 @@ def inline_calls_in(self, graph: ir.Graph) -> bool:
300309 self ._function_id_abbreviations [id ] + call_site_prefix
301310 )
302311 nodes , values = self ._instantiate_call (node , call_site )
303- ir_convenience .replace_nodes_and_values (
312+ _ir_convenience .replace_nodes_and_values (
304313 graph ,
305314 insertion_point = node ,
306315 old_nodes = [node ],
@@ -313,14 +322,8 @@ def inline_calls_in(self, graph: ir.Graph) -> bool:
313322 if not isinstance (attr , ir .Attr ):
314323 continue
315324 if attr .type == ir .AttributeType .GRAPH :
316- self .inline_calls_in (attr .as_graph ())
325+ self ._inline_calls_in (attr .as_graph ())
317326 elif attr .type == ir .AttributeType .GRAPHS :
318- for graph in attr .as_graphs ():
319- self .inline_calls_in (graph )
320- return bool (id_count )
321-
322-
323- def inline (model : ir .Model ) -> None :
324- """Inline all function calls (recursively) in the model."""
325- if model .functions :
326- InlinePass ()(model )
327+ for g in attr .as_graphs ():
328+ self ._inline_calls_in (g )
329+ return id_count
0 commit comments