Skip to content

Commit 5a9e00b

Browse files
author
Justin Chu
committed
Update impl
1 parent 50322f9 commit 5a9e00b

File tree

2 files changed

+119
-82
lines changed

2 files changed

+119
-82
lines changed

onnxscript/_internal/builder.py

Lines changed: 78 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -511,13 +511,13 @@ def call_op(
511511
self,
512512
op_type: str,
513513
inputs: Sequence[ir.Value | ir.TensorProtocol],
514-
kwargs: dict[str, Any],
514+
kwargs: dict[str, ir.Value | ir.TensorProtocol],
515+
/,
516+
domain: str = "",
517+
version: int | None = None,
518+
outputs: int | Sequence[str | ir.Value] = 1,
515519
):
516520
"""Create an ONNX node and add it to the graph, returning its output value(s)."""
517-
domain = kwargs.pop("_domain", "")
518-
version = kwargs.pop("_version", None)
519-
outputs = kwargs.pop("_outputs", 1)
520-
521521
count = self.graph.num_nodes()
522522
node_name = self._qualify_node_name(f"{op_type}_node_{count}")
523523

@@ -548,12 +548,49 @@ def call_op(
548548
return node.outputs if len(node.outputs) > 1 else node.outputs[0]
549549

550550
def call(
551+
self,
552+
function: ir.Function | onnxscript.OnnxFunction,
553+
*args,
554+
_outputs: int | Sequence[str | ir.Value] | None = None,
555+
**kwargs,
556+
):
557+
"""Call a function as a single function node."""
558+
if isinstance(function, ir.Function):
559+
graph = function.graph
560+
elif isinstance(function, onnxscript.OnnxFunction):
561+
graph = function.graph()
562+
function = function.function_ir
563+
else:
564+
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
565+
566+
if _outputs is None:
567+
_outputs = len(graph.outputs)
568+
output_values = self._adapt_outputs(_outputs, function.name)
569+
570+
node = ir.node(
571+
op_type=function.name,
572+
inputs=args,
573+
attributes=kwargs or None,
574+
outputs=output_values,
575+
domain=function.domain,
576+
name=self._qualify_node_name(function.name),
577+
)
578+
# Attach scope metadata to the node
579+
node.metadata_props["namespace"] = self._build_namespace()
580+
node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes())
581+
node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names())
582+
583+
self.add_node(node)
584+
self._functions[function.identifier()] = function
585+
586+
return node.outputs if len(node.outputs) > 1 else node.outputs[0]
587+
588+
def call_inline(
551589
self,
552590
function: ir.Function | onnxscript.OnnxFunction,
553591
*args,
554592
_outputs: Sequence[str] | None = None,
555593
_prefix: str = "",
556-
_inline: bool = True,
557594
**kwargs,
558595
):
559596
if isinstance(function, ir.Function):
@@ -575,35 +612,21 @@ def call(
575612
else:
576613
for output in graph.outputs:
577614
output_renaming[output.name] = self._qualify_value_name(output.name)
615+
616+
nodes, outputs = _inliner.instantiate(graph, args, kwargs)
617+
578618
if _prefix:
579619
self.push_module(_prefix)
580620

581-
if _inline:
582-
nodes, outputs = _inliner.instantiate(graph, args, kwargs)
583-
584-
for node in nodes:
585-
node.name = self._qualify_node_name(node.name)
586-
for output in node.outputs:
587-
if output.name:
588-
if output.name in output_renaming:
589-
output.name = output_renaming[output.name]
590-
else:
591-
output.name = self._qualify_value_name(output.name)
592-
self.add_node(node)
593-
else:
594-
node = ir.node(
595-
op_type=function.name,
596-
inputs=args,
597-
attributes=kwargs or None,
598-
outputs=[
599-
ir.Value(name=output_renaming[output.name]) for output in graph.outputs
600-
],
601-
domain=function.domain,
602-
name=self._qualify_node_name(function.name),
603-
)
604-
outputs = node.outputs
621+
for node in nodes:
622+
node.name = self._qualify_node_name(node.name)
623+
for output in node.outputs:
624+
if output.name:
625+
if output.name in output_renaming:
626+
output.name = output_renaming[output.name]
627+
else:
628+
output.name = self._qualify_value_name(output.name)
605629
self.add_node(node)
606-
self._functions[function.identifier()] = function
607630

608631
if _prefix:
609632
self.pop_module()
@@ -714,30 +737,46 @@ def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]:
714737
return self._builder.functions
715738

716739
def call(
740+
self,
741+
function,
742+
*args,
743+
_outputs: Sequence[str] | int | None = None,
744+
**kwargs,
745+
):
746+
"""Call a function as a single function node.
747+
748+
Args:
749+
function: The function to call (ir.Function or onnxscript.OnnxFunction).
750+
*args: Positional arguments to pass to the function.
751+
_outputs: Optional sequence of output names, or an integer specifying the number of outputs.
752+
**kwargs: Keyword arguments to pass to the function.
753+
754+
Returns:
755+
The output value(s) from the function call.
756+
"""
757+
return self._builder.call(function, *args, _outputs=_outputs, **kwargs)
758+
759+
def call_inline(
717760
self,
718761
function,
719762
*args,
720763
_outputs: Sequence[str] | None = None,
721764
_prefix: str = "",
722-
_inline: bool = True,
723765
**kwargs,
724766
):
725-
"""Call a function and optionally inline it into the graph.
767+
"""Inline a function body into the current graph.
726768
727769
Args:
728770
function: The function to call (ir.Function or onnxscript.OnnxFunction).
729771
*args: Positional arguments to pass to the function.
730772
_outputs: Optional sequence of output names. If provided, must match the
731773
number of function outputs.
732774
_prefix: Optional prefix for module scoping (e.g., "layers.0").
733-
_inline: If True, the function body is inlined into the caller graph instead of being
734-
called as a separate node. When False, the function will be added
735-
to the ``.functions`` dictionary. Defaults to True.
736775
**kwargs: Keyword arguments to pass to the function.
737776
738777
Returns:
739-
The output value(s) from the function call.
778+
The output value(s) from the inlined function body.
740779
"""
741-
return self._builder.call(
742-
function, *args, _outputs=_outputs, _prefix=_prefix, _inline=_inline, **kwargs
780+
return self._builder.call_inline(
781+
function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs
743782
)

0 commit comments

Comments
 (0)