@@ -511,13 +511,13 @@ def call_op(
511511 self ,
512512 op_type : str ,
513513 inputs : Sequence [ir .Value | ir .TensorProtocol ],
514- kwargs : dict [str , Any ],
514+ kwargs : dict [str , ir .Value | ir .TensorProtocol ],
515+ / ,
516+ domain : str = "" ,
517+ version : int | None = None ,
518+ outputs : int | Sequence [str | ir .Value ] = 1 ,
515519 ):
516520 """Create an ONNX node and add it to the graph, returning its output value(s)."""
517- domain = kwargs .pop ("_domain" , "" )
518- version = kwargs .pop ("_version" , None )
519- outputs = kwargs .pop ("_outputs" , 1 )
520-
521521 count = self .graph .num_nodes ()
522522 node_name = self ._qualify_node_name (f"{ op_type } _node_{ count } " )
523523
@@ -548,12 +548,49 @@ def call_op(
548548 return node .outputs if len (node .outputs ) > 1 else node .outputs [0 ]
549549
550550 def call (
551+ self ,
552+ function : ir .Function | onnxscript .OnnxFunction ,
553+ * args ,
554+ _outputs : int | Sequence [str | ir .Value ] | None = None ,
555+ ** kwargs ,
556+ ):
557+ """Call a function as a single function node."""
558+ if isinstance (function , ir .Function ):
559+ graph = function .graph
560+ elif isinstance (function , onnxscript .OnnxFunction ):
561+ graph = function .graph ()
562+ function = function .function_ir
563+ else :
564+ raise TypeError ("Function must be an ir.Function or onnxscript.OnnxFunction" )
565+
566+ if _outputs is None :
567+ _outputs = len (graph .outputs )
568+ output_values = self ._adapt_outputs (_outputs , function .name )
569+
570+ node = ir .node (
571+ op_type = function .name ,
572+ inputs = args ,
573+ attributes = kwargs or None ,
574+ outputs = output_values ,
575+ domain = function .domain ,
576+ name = self ._qualify_node_name (function .name ),
577+ )
578+ # Attach scope metadata to the node
579+ node .metadata_props ["namespace" ] = self ._build_namespace ()
580+ node .metadata_props ["pkg.onnxscript.class_hierarchy" ] = repr (self ._scope_classes ())
581+ node .metadata_props ["pkg.onnxscript.name_scopes" ] = repr (self ._scope_names ())
582+
583+ self .add_node (node )
584+ self ._functions [function .identifier ()] = function
585+
586+ return node .outputs if len (node .outputs ) > 1 else node .outputs [0 ]
587+
588+ def call_inline (
551589 self ,
552590 function : ir .Function | onnxscript .OnnxFunction ,
553591 * args ,
554592 _outputs : Sequence [str ] | None = None ,
555593 _prefix : str = "" ,
556- _inline : bool = True ,
557594 ** kwargs ,
558595 ):
559596 if isinstance (function , ir .Function ):
@@ -575,35 +612,21 @@ def call(
575612 else :
576613 for output in graph .outputs :
577614 output_renaming [output .name ] = self ._qualify_value_name (output .name )
615+
616+ nodes , outputs = _inliner .instantiate (graph , args , kwargs )
617+
578618 if _prefix :
579619 self .push_module (_prefix )
580620
581- if _inline :
582- nodes , outputs = _inliner .instantiate (graph , args , kwargs )
583-
584- for node in nodes :
585- node .name = self ._qualify_node_name (node .name )
586- for output in node .outputs :
587- if output .name :
588- if output .name in output_renaming :
589- output .name = output_renaming [output .name ]
590- else :
591- output .name = self ._qualify_value_name (output .name )
592- self .add_node (node )
593- else :
594- node = ir .node (
595- op_type = function .name ,
596- inputs = args ,
597- attributes = kwargs or None ,
598- outputs = [
599- ir .Value (name = output_renaming [output .name ]) for output in graph .outputs
600- ],
601- domain = function .domain ,
602- name = self ._qualify_node_name (function .name ),
603- )
604- outputs = node .outputs
621+ for node in nodes :
622+ node .name = self ._qualify_node_name (node .name )
623+ for output in node .outputs :
624+ if output .name :
625+ if output .name in output_renaming :
626+ output .name = output_renaming [output .name ]
627+ else :
628+ output .name = self ._qualify_value_name (output .name )
605629 self .add_node (node )
606- self ._functions [function .identifier ()] = function
607630
608631 if _prefix :
609632 self .pop_module ()
@@ -714,30 +737,46 @@ def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]:
714737 return self ._builder .functions
715738
716739 def call (
740+ self ,
741+ function ,
742+ * args ,
743+ _outputs : Sequence [str ] | int | None = None ,
744+ ** kwargs ,
745+ ):
746+ """Call a function as a single function node.
747+
748+ Args:
749+ function: The function to call (ir.Function or onnxscript.OnnxFunction).
750+ *args: Positional arguments to pass to the function.
751+ _outputs: Optional sequence of output names, or an integer specifying the number of outputs.
752+ **kwargs: Keyword arguments to pass to the function.
753+
754+ Returns:
755+ The output value(s) from the function call.
756+ """
757+ return self ._builder .call (function , * args , _outputs = _outputs , ** kwargs )
758+
759+ def call_inline (
717760 self ,
718761 function ,
719762 * args ,
720763 _outputs : Sequence [str ] | None = None ,
721764 _prefix : str = "" ,
722- _inline : bool = True ,
723765 ** kwargs ,
724766 ):
725- """Call a function and optionally inline it into the graph.
767+ """Inline a function body into the current graph.
726768
727769 Args:
728770 function: The function to call (ir.Function or onnxscript.OnnxFunction).
729771 *args: Positional arguments to pass to the function.
730772 _outputs: Optional sequence of output names. If provided, must match the
731773 number of function outputs.
732774 _prefix: Optional prefix for module scoping (e.g., "layers.0").
733- _inline: If True, the function body is inlined into the caller graph instead of being
734- called as a separate node. When False, the function will be added
735- to the ``.functions`` dictionary. Defaults to True.
736775 **kwargs: Keyword arguments to pass to the function.
737776
738777 Returns:
739- The output value(s) from the function call .
778+ The output value(s) from the inlined function body .
740779 """
741- return self ._builder .call (
742- function , * args , _outputs = _outputs , _prefix = _prefix , _inline = _inline , ** kwargs
780+ return self ._builder .call_inline (
781+ function , * args , _outputs = _outputs , _prefix = _prefix , ** kwargs
743782 )
0 commit comments