@@ -446,6 +446,7 @@ def __init__(self, graph: ir.Graph, *, parent: GraphBuilder | None = None) -> No
446446 # visible to subgraphs per the ONNX spec).
447447 if parent is None :
448448 self ._constant_cache : dict [tuple [Any , ir .DataType | None ], ir .Value ] = {}
449+ self ._functions : dict [ir .OperatorIdentifier , ir .Function ] = {}
449450
450451 def opset (self , domain : str , version : int = 1 ) -> OpBuilder :
451452 """Create an OpBuilder bound to the given domain and version."""
@@ -469,6 +470,10 @@ def root(self) -> GraphBuilder:
469470 def graph (self ) -> ir .Graph :
470471 return self ._graph
471472
473+ @property
474+ def functions (self ) -> dict [ir .OperatorIdentifier , ir .Function ]:
475+ return self ._root ._functions
476+
472477 def initializer (
473478 self , tensor : ir .TensorProtocol , name : str | None = None , * , qualify : bool = True
474479 ) -> ir .Value :
@@ -796,12 +801,12 @@ def call_op(
796801 op_type : str ,
797802 inputs : Sequence [ir .Value | ir .TensorProtocol | None ],
798803 kwargs : dict [str , Any ],
804+ / ,
805+ domain : str = "" ,
806+ version : int | None = None ,
807+ outputs : int | Sequence [str | ir .Value ] = 1 ,
799808 ):
800809 """Create an ONNX node and add it to the graph, returning its output value(s)."""
801- domain = kwargs .pop ("_domain" , "" )
802- version = kwargs .pop ("_version" , None )
803- outputs = kwargs .pop ("_outputs" , 1 )
804-
805810 count = self .graph .num_nodes ()
806811 node_name = self ._qualify_node_name (f"{ op_type } _node_{ count } " )
807812
@@ -833,7 +838,54 @@ def call_op(
833838
834839 def call (
835840 self ,
836- function ,
841+ function : ir .Function | onnxscript .OnnxFunction ,
842+ * args ,
843+ _outputs : int | Sequence [str | ir .Value ] | None = None ,
844+ ** kwargs ,
845+ ):
846+ """Call a function as a single function node."""
847+ if isinstance (function , ir .Function ):
848+ graph = function .graph
849+ elif isinstance (function , onnxscript .OnnxFunction ):
850+ graph = function .graph ()
851+ function = function .function_ir
852+ else :
853+ raise TypeError ("Function must be an ir.Function or onnxscript.OnnxFunction" )
854+
855+ if _outputs is None :
856+ _outputs = len (graph .outputs )
857+ output_values = self ._adapt_outputs (_outputs , function .name )
858+
859+ # Adapt inputs similarly to call_op: promote constants/tensors to ir.Value.
860+ adapted_args = [self ._input_to_ir_value (arg ) for arg in args ]
861+
862+ count = self .graph .num_nodes ()
863+ node_name = self ._qualify_node_name (f"{ function .name } _node_{ count } " )
864+
865+ node = ir .node (
866+ op_type = function .name ,
867+ inputs = adapted_args ,
868+ attributes = kwargs or None ,
869+ outputs = output_values ,
870+ domain = function .domain ,
871+ name = node_name ,
872+ overload = function .overload ,
873+ )
874+ # Attach scope metadata to the node
875+ node .metadata_props ["namespace" ] = self ._build_namespace ()
876+ node .metadata_props ["pkg.onnxscript.class_hierarchy" ] = repr (self ._scope_classes ())
877+ node .metadata_props ["pkg.onnxscript.name_scopes" ] = repr (self ._scope_names ())
878+
879+ self .add_node (node )
880+ self ._root ._functions [function .identifier ()] = function
881+
882+ if len (node .outputs ) == 0 :
883+ return ()
884+ return node .outputs if len (node .outputs ) > 1 else node .outputs [0 ]
885+
886+ def call_inline (
887+ self ,
888+ function : ir .Function | onnxscript .OnnxFunction ,
837889 * args ,
838890 _outputs : Sequence [str ] | None = None ,
839891 _prefix : str = "" ,
@@ -842,35 +894,56 @@ def call(
842894 if isinstance (function , ir .Function ):
843895 graph = function .graph
844896 elif isinstance (function , onnxscript .OnnxFunction ):
845- graph = function .graph ()
897+ # TODO(justinchuby): Reason about support for outer-scope values in inlined function bodies.
898+ graph = function .graph ().clone (allow_outer_scope_values = True )
846899 else :
847900 raise TypeError ("Function must be an ir.Function or onnxscript.OnnxFunction" )
848- output_renaming : dict [str , str ] = {}
849901 if _outputs is not None :
850902 if len (_outputs ) != len (graph .outputs ):
851903 raise ValueError (
852- f"Number of provided output names { _outputs } does not match "
904+ f"Number of rovided output names { _outputs } does not match "
853905 f"number of function outputs { len (graph .outputs )} ."
854906 )
855- for output , name in zip (graph .outputs , _outputs ):
856- output_renaming [output .name ] = self ._qualify_value_name (name )
907+ # Compute desired output names before pushing prefix scope so they
908+ # are not affected by the prefix.
909+ desired_output_names : list [str ] = [
910+ self ._qualify_value_name (name ) for name in _outputs
911+ ]
857912 else :
858- for output in graph .outputs :
859- output_renaming [output .name ] = self ._qualify_value_name (output .name )
860- nodes , outputs = _inliner .instantiate (graph , args , kwargs )
913+ desired_output_names = []
914+
861915 if _prefix :
862916 self .push_module (_prefix )
917+
918+ count = self .graph .num_nodes ()
919+ node_name_prefix = self ._qualify_node_name (f"{ function .name } _node_{ count } /" )
920+ nodes , outputs = _inliner .instantiate (graph , args , kwargs , prefix = node_name_prefix )
921+
922+ # Track final output values so we can rename them separately.
923+ # The inliner prefixes all names, which would prevent name-based lookup
924+ # from matching the original graph output names.
925+ output_value_ids = {id (v ) for v in outputs if v is not None }
926+
863927 for node in nodes :
864- node .name = self ._qualify_node_name (node .name )
865928 for output in node .outputs :
866- if output .name :
867- if output .name in output_renaming :
868- output .name = output_renaming [output .name ]
869- else :
870- output .name = self ._qualify_value_name (output .name )
929+ if output .name and id (output ) not in output_value_ids :
930+ output .name = self ._qualify_value_name (output .name )
871931 self .add_node (node )
932+
933+ # Apply names to final output values
934+ if desired_output_names :
935+ for output_val , name in zip (outputs , desired_output_names ):
936+ if output_val is not None :
937+ output_val .name = name
938+ else :
939+ for output_val in outputs :
940+ if output_val is not None and output_val .name :
941+ output_val .name = self ._qualify_value_name (output_val .name )
942+
872943 if _prefix :
873944 self .pop_module ()
945+ if len (outputs ) == 0 :
946+ return ()
874947 return outputs if len (outputs ) > 1 else outputs [0 ]
875948
876949 def push_module (self , module : str , class_name : str = "" ) -> None :
@@ -962,27 +1035,52 @@ def version(self) -> int | None:
9621035 return self ._version
9631036
9641037 def _call_op (self , op_type : str , inputs : Sequence [Any ], kwargs : dict [str , Any ]):
965- if "_domain" not in kwargs :
966- kwargs ["_domain" ] = self ._domain
967- if self ._version is not None and "_version" not in kwargs :
968- kwargs ["_version" ] = self ._version
969- return self ._builder .call_op (op_type , inputs , kwargs )
1038+ domain = kwargs .pop ("_domain" , self ._domain )
1039+ version = kwargs .pop ("_version" , self ._version )
1040+ outputs = kwargs .pop ("_outputs" , 1 )
1041+ return self ._builder .call_op (
1042+ op_type , inputs , kwargs , domain = domain , version = version , outputs = outputs
1043+ )
9701044
9711045 def __getattr__ (self , op_type : str ) -> Callable :
9721046 return lambda * args , ** kwargs : self ._call_op (op_type , args , kwargs )
9731047
9741048 def initializer (self , tensor : ir .TensorProtocol , name : str | None = None ) -> ir .Value :
9751049 return self ._builder .initializer (tensor , name )
9761050
1051+ @property
1052+ def functions (self ) -> dict [ir .OperatorIdentifier , ir .Function ]:
1053+ return self ._builder .functions
1054+
9771055 def call (
1056+ self ,
1057+ function ,
1058+ * args ,
1059+ _outputs : Sequence [str ] | int | None = None ,
1060+ ** kwargs ,
1061+ ):
1062+ """Call a function as a single function node.
1063+
1064+ Args:
1065+ function: The function to call (ir.Function or onnxscript.OnnxFunction).
1066+ *args: Positional arguments to pass to the function.
1067+ _outputs: Optional sequence of output names, or an integer specifying the number of outputs.
1068+ **kwargs: Keyword arguments to pass to the function.
1069+
1070+ Returns:
1071+ The output value(s) from the function call.
1072+ """
1073+ return self ._builder .call (function , * args , _outputs = _outputs , ** kwargs )
1074+
1075+ def call_inline (
9781076 self ,
9791077 function ,
9801078 * args ,
9811079 _outputs : Sequence [str ] | None = None ,
9821080 _prefix : str = "" ,
9831081 ** kwargs ,
9841082 ):
985- """Call a function and inline it into the graph.
1083+ """Inline a function body into the current graph.
9861084
9871085 Args:
9881086 function: The function to call (ir.Function or onnxscript.OnnxFunction).
@@ -993,8 +1091,8 @@ def call(
9931091 **kwargs: Keyword arguments to pass to the function.
9941092
9951093 Returns:
996- The output value(s) from the function call .
1094+ The output value(s) from the inlined function body .
9971095 """
998- return self ._builder .call (
1096+ return self ._builder .call_inline (
9991097 function , * args , _outputs = _outputs , _prefix = _prefix , ** kwargs
10001098 )
0 commit comments