Skip to content

Commit df97c94

Browse files
authored
Add an option to not inline a function when building the graph (#2851)
- Introduced distinct `call` and `call_inline` methods in `GraphBuilder` and `OpBuilder` to differentiate between creating a single function call node (`call`) and inlining a function's body directly into the graph (`call_inline`). The `call` method now registers the function in the builder, while `call_inline` does not. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 90f754a commit df97c94

File tree

3 files changed

+287
-45
lines changed

3 files changed

+287
-45
lines changed

onnxscript/_internal/builder.py

Lines changed: 125 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -446,6 +446,7 @@ def __init__(self, graph: ir.Graph, *, parent: GraphBuilder | None = None) -> No
446446
# visible to subgraphs per the ONNX spec).
447447
if parent is None:
448448
self._constant_cache: dict[tuple[Any, ir.DataType | None], ir.Value] = {}
449+
self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
449450

450451
def opset(self, domain: str, version: int = 1) -> OpBuilder:
451452
"""Create an OpBuilder bound to the given domain and version."""
@@ -469,6 +470,10 @@ def root(self) -> GraphBuilder:
469470
def graph(self) -> ir.Graph:
470471
return self._graph
471472

473+
@property
474+
def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]:
475+
return self._root._functions
476+
472477
def initializer(
473478
self, tensor: ir.TensorProtocol, name: str | None = None, *, qualify: bool = True
474479
) -> ir.Value:
@@ -796,12 +801,12 @@ def call_op(
796801
op_type: str,
797802
inputs: Sequence[ir.Value | ir.TensorProtocol | None],
798803
kwargs: dict[str, Any],
804+
/,
805+
domain: str = "",
806+
version: int | None = None,
807+
outputs: int | Sequence[str | ir.Value] = 1,
799808
):
800809
"""Create an ONNX node and add it to the graph, returning its output value(s)."""
801-
domain = kwargs.pop("_domain", "")
802-
version = kwargs.pop("_version", None)
803-
outputs = kwargs.pop("_outputs", 1)
804-
805810
count = self.graph.num_nodes()
806811
node_name = self._qualify_node_name(f"{op_type}_node_{count}")
807812

@@ -833,7 +838,54 @@ def call_op(
833838

834839
def call(
835840
self,
836-
function,
841+
function: ir.Function | onnxscript.OnnxFunction,
842+
*args,
843+
_outputs: int | Sequence[str | ir.Value] | None = None,
844+
**kwargs,
845+
):
846+
"""Call a function as a single function node."""
847+
if isinstance(function, ir.Function):
848+
graph = function.graph
849+
elif isinstance(function, onnxscript.OnnxFunction):
850+
graph = function.graph()
851+
function = function.function_ir
852+
else:
853+
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
854+
855+
if _outputs is None:
856+
_outputs = len(graph.outputs)
857+
output_values = self._adapt_outputs(_outputs, function.name)
858+
859+
# Adapt inputs similarly to call_op: promote constants/tensors to ir.Value.
860+
adapted_args = [self._input_to_ir_value(arg) for arg in args]
861+
862+
count = self.graph.num_nodes()
863+
node_name = self._qualify_node_name(f"{function.name}_node_{count}")
864+
865+
node = ir.node(
866+
op_type=function.name,
867+
inputs=adapted_args,
868+
attributes=kwargs or None,
869+
outputs=output_values,
870+
domain=function.domain,
871+
name=node_name,
872+
overload=function.overload,
873+
)
874+
# Attach scope metadata to the node
875+
node.metadata_props["namespace"] = self._build_namespace()
876+
node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes())
877+
node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names())
878+
879+
self.add_node(node)
880+
self._root._functions[function.identifier()] = function
881+
882+
if len(node.outputs) == 0:
883+
return ()
884+
return node.outputs if len(node.outputs) > 1 else node.outputs[0]
885+
886+
def call_inline(
887+
self,
888+
function: ir.Function | onnxscript.OnnxFunction,
837889
*args,
838890
_outputs: Sequence[str] | None = None,
839891
_prefix: str = "",
@@ -842,35 +894,56 @@ def call(
842894
if isinstance(function, ir.Function):
843895
graph = function.graph
844896
elif isinstance(function, onnxscript.OnnxFunction):
845-
graph = function.graph()
897+
# TODO(justinchuby): Reason about support for outer-scope values in inlined function bodies.
898+
graph = function.graph().clone(allow_outer_scope_values=True)
846899
else:
847900
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
848-
output_renaming: dict[str, str] = {}
849901
if _outputs is not None:
850902
if len(_outputs) != len(graph.outputs):
851903
raise ValueError(
852-
f"Number of provided output names {_outputs} does not match "
904+
f"Number of provided output names {_outputs} does not match "
853905
f"number of function outputs {len(graph.outputs)}."
854906
)
855-
for output, name in zip(graph.outputs, _outputs):
856-
output_renaming[output.name] = self._qualify_value_name(name)
907+
# Compute desired output names before pushing prefix scope so they
908+
# are not affected by the prefix.
909+
desired_output_names: list[str] = [
910+
self._qualify_value_name(name) for name in _outputs
911+
]
857912
else:
858-
for output in graph.outputs:
859-
output_renaming[output.name] = self._qualify_value_name(output.name)
860-
nodes, outputs = _inliner.instantiate(graph, args, kwargs)
913+
desired_output_names = []
914+
861915
if _prefix:
862916
self.push_module(_prefix)
917+
918+
count = self.graph.num_nodes()
919+
node_name_prefix = self._qualify_node_name(f"{function.name}_node_{count}/")
920+
nodes, outputs = _inliner.instantiate(graph, args, kwargs, prefix=node_name_prefix)
921+
922+
# Track final output values so we can rename them separately.
923+
# The inliner prefixes all names, which would prevent name-based lookup
924+
# from matching the original graph output names.
925+
output_value_ids = {id(v) for v in outputs if v is not None}
926+
863927
for node in nodes:
864-
node.name = self._qualify_node_name(node.name)
865928
for output in node.outputs:
866-
if output.name:
867-
if output.name in output_renaming:
868-
output.name = output_renaming[output.name]
869-
else:
870-
output.name = self._qualify_value_name(output.name)
929+
if output.name and id(output) not in output_value_ids:
930+
output.name = self._qualify_value_name(output.name)
871931
self.add_node(node)
932+
933+
# Apply names to final output values
934+
if desired_output_names:
935+
for output_val, name in zip(outputs, desired_output_names):
936+
if output_val is not None:
937+
output_val.name = name
938+
else:
939+
for output_val in outputs:
940+
if output_val is not None and output_val.name:
941+
output_val.name = self._qualify_value_name(output_val.name)
942+
872943
if _prefix:
873944
self.pop_module()
945+
if len(outputs) == 0:
946+
return ()
874947
return outputs if len(outputs) > 1 else outputs[0]
875948

876949
def push_module(self, module: str, class_name: str = "") -> None:
@@ -962,27 +1035,52 @@ def version(self) -> int | None:
9621035
return self._version
9631036

9641037
def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]):
965-
if "_domain" not in kwargs:
966-
kwargs["_domain"] = self._domain
967-
if self._version is not None and "_version" not in kwargs:
968-
kwargs["_version"] = self._version
969-
return self._builder.call_op(op_type, inputs, kwargs)
1038+
domain = kwargs.pop("_domain", self._domain)
1039+
version = kwargs.pop("_version", self._version)
1040+
outputs = kwargs.pop("_outputs", 1)
1041+
return self._builder.call_op(
1042+
op_type, inputs, kwargs, domain=domain, version=version, outputs=outputs
1043+
)
9701044

9711045
def __getattr__(self, op_type: str) -> Callable:
9721046
return lambda *args, **kwargs: self._call_op(op_type, args, kwargs)
9731047

9741048
def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
9751049
return self._builder.initializer(tensor, name)
9761050

1051+
@property
1052+
def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]:
1053+
return self._builder.functions
1054+
9771055
def call(
1056+
self,
1057+
function,
1058+
*args,
1059+
_outputs: Sequence[str] | int | None = None,
1060+
**kwargs,
1061+
):
1062+
"""Call a function as a single function node.
1063+
1064+
Args:
1065+
function: The function to call (ir.Function or onnxscript.OnnxFunction).
1066+
*args: Positional arguments to pass to the function.
1067+
_outputs: Optional sequence of output names, or an integer specifying the number of outputs.
1068+
**kwargs: Keyword arguments to pass to the function.
1069+
1070+
Returns:
1071+
The output value(s) from the function call.
1072+
"""
1073+
return self._builder.call(function, *args, _outputs=_outputs, **kwargs)
1074+
1075+
def call_inline(
9781076
self,
9791077
function,
9801078
*args,
9811079
_outputs: Sequence[str] | None = None,
9821080
_prefix: str = "",
9831081
**kwargs,
9841082
):
985-
"""Call a function and inline it into the graph.
1083+
"""Inline a function body into the current graph.
9861084
9871085
Args:
9881086
function: The function to call (ir.Function or onnxscript.OnnxFunction).
@@ -993,8 +1091,8 @@ def call(
9931091
**kwargs: Keyword arguments to pass to the function.
9941092
9951093
Returns:
996-
The output value(s) from the function call.
1094+
The output value(s) from the inlined function body.
9971095
"""
998-
return self._builder.call(
1096+
return self._builder.call_inline(
9991097
function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs
10001098
)

0 commit comments

Comments
 (0)