55from __future__ import annotations
66
77import collections
8- from typing import Any , OrderedDict , Sequence
8+ from typing import Any , OrderedDict
99
10- from onnxscript ._internal import values
10+ from onnxscript .ir import _schemas
1111
1212
1313def separate_input_attributes_from_arguments (
14- param_schemas : Sequence [ values . ParamSchema ] ,
14+ op_signature : _schemas . OpSignature ,
1515 args ,
1616 kwargs ,
1717 fill_defaults : bool = True ,
@@ -20,7 +20,7 @@ def separate_input_attributes_from_arguments(
2020 """Separate Python args and kwargs into ONNX inputs and attributes.
2121
2222 Args:
23- param_schemas : The parameter schemas of an Op or a OnnxFunction .
23+ op_signature : The operator signature containing parameter information .
2424 args: The Python positional arguments supplied by the caller.
2525 kwargs: The Python keyword arguments supplied by the caller.
2626 fill_defaults: Whether to fill the default values for attributes.
@@ -36,56 +36,61 @@ def separate_input_attributes_from_arguments(
3636 TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
3737 TypeError: When a required input is not provided.
3838 """
39- # args, kwargs and param_schemas should be all in order
39+ # args, kwargs and op_signature.params should be all in order
4040 # user may not specify all inputs or attributes
4141
42- all_param_names = {param .name for param in param_schemas }
42+ all_param_names = {param .name for param in op_signature . params }
4343 extra_kwargs = set (kwargs ).difference (all_param_names )
4444 if extra_kwargs and not allow_extra_kwargs :
4545 raise TypeError (f"Unexpected keyword arguments '{ extra_kwargs } '" )
4646
4747 onnx_inputs = []
4848 onnx_attributes = collections .OrderedDict ()
4949
50- for i , param in enumerate (param_schemas ):
51- if param .is_variadic_input :
50+ for i , param in enumerate (op_signature .params ):
51+ is_input = isinstance (param , _schemas .Parameter )
52+ is_variadic = isinstance (param , _schemas .Parameter ) and param .variadic
53+
54+ if is_variadic :
5255 # Exhaust all remaining args
5356 onnx_inputs .extend (args [i :])
5457 args = []
5558 continue
5659 if i < len (args ):
57- if param . is_input :
60+ if is_input :
5861 onnx_inputs .append (args [i ])
5962 else :
6063 onnx_attributes [param .name ] = args [i ]
6164 elif param .name in kwargs :
62- if param . is_input :
65+ if is_input :
6366 onnx_inputs .append (kwargs [param .name ])
6467 else :
6568 onnx_attributes [param .name ] = kwargs [param .name ]
66- elif (
67- param .is_attribute and param .default is not values ._EmptyDefault # pylint: disable=protected-access
68- ):
69+ elif isinstance (param , _schemas .AttributeParameter ) and param .has_default ():
6970 # User did not provide the attribute
7071 if fill_defaults :
71- onnx_attributes [param .name ] = param .default
72+ # Extract the value from the Attr object
73+ onnx_attributes [param .name ] = param .default .value
7274 elif param .required :
7375 raise TypeError (f"Required input '{ param } ' was not provided" )
7476
7577 return onnx_inputs , onnx_attributes
7678
7779
78- def tag_arguments_with_param_schemas (
79- param_schemas : Sequence [ values . ParamSchema ] ,
80+ def tag_arguments_with_signature (
81+ op_signature : _schemas . OpSignature ,
8082 args ,
8183 kwargs ,
8284 fill_defaults : bool = True ,
8385 allow_extra_kwargs : bool = False ,
84- ) -> tuple [list [tuple [Any , values .ParamSchema ]], dict [str , tuple [Any , values .ParamSchema ]]]:
85- """Tag Python args and kwargs with matching ONNX ParamSchema.
86+ ) -> tuple [
87+ list [tuple [Any , _schemas .Parameter | _schemas .AttributeParameter ]],
88+ dict [str , tuple [Any , _schemas .Parameter | _schemas .AttributeParameter ]],
89+ ]:
90+ """Tag Python args and kwargs with matching ONNX Parameter/AttributeParameter.
8691
8792 Args:
88- param_schemas : The parameter schemas of an Op or a OnnxFunction .
93+ op_signature : The operator signature containing parameter information .
8994 args: The Python positional arguments supplied by the caller.
9095 kwargs: The Python keyword arguments supplied by the caller.
9196 fill_defaults: Whether to fill the default values for attributes.
@@ -94,27 +99,29 @@ def tag_arguments_with_param_schemas(
9499
95100 Returns:
96101 A tuple of two elements:
97- - A list of tuple of Python positional argument and ParamSchema .
102+ - A list of tuple of Python positional argument and Parameter/AttributeParameter .
98103 - An ordered dictionary of Python keyword argument names and tuple of argument
99- value and ParamSchema .
104+ value and Parameter/AttributeParameter .
100105
101106 Raises:
102107 TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
103108 TypeError: When a required input is not provided.
104109 """
105- # args, kwargs and param_schemas should be all in order
110+ # args, kwargs and op_signature.params should be all in order
106111 # user may not specify all inputs or attributes
107112
108- all_param_names = {param .name for param in param_schemas }
113+ all_param_names = {param .name for param in op_signature . params }
109114 extra_kwargs = set (kwargs ).difference (all_param_names )
110115 if extra_kwargs and not allow_extra_kwargs :
111116 raise TypeError (f"Unexpected keyword arguments '{ extra_kwargs } '" )
112117
113- tagged_args : list [tuple [Any , values .ParamSchema ]] = []
114- tagged_kwargs : dict [str , tuple [Any , values .ParamSchema ]] = {}
118+ tagged_args : list [tuple [Any , _schemas .Parameter | _schemas .AttributeParameter ]] = []
119+ tagged_kwargs : dict [str , tuple [Any , _schemas .Parameter | _schemas .AttributeParameter ]] = {}
120+
121+ for i , param in enumerate (op_signature .params ):
122+ is_variadic = isinstance (param , _schemas .Parameter ) and param .variadic
115123
116- for i , param in enumerate (param_schemas ):
117- if param .is_variadic_input :
124+ if is_variadic :
118125 # Exhaust all remaining args
119126 tagged_args .extend ((arg , param ) for arg in args [i :])
120127 args = []
@@ -123,25 +130,30 @@ def tag_arguments_with_param_schemas(
123130 tagged_args .append ((args [i ], param ))
124131 elif param .name in kwargs :
125132 tagged_kwargs [param .name ] = (kwargs [param .name ], param )
126- elif param .default is not values . _EmptyDefault : # pylint: disable=protected-access
133+ elif param .has_default ():
127134 # User did not provide the input/attribute
128135 if fill_defaults :
129- tagged_kwargs [param .name ] = (param .default , param )
136+ default_value = param .default
137+ # Extract value from Attr object if it's an AttributeParameter
138+ if isinstance (param , _schemas .AttributeParameter ):
139+ default_value = param .default .value
140+ tagged_kwargs [param .name ] = (default_value , param )
130141 elif param .required :
131142 raise TypeError (f"Required input/attribute '{ param } ' was not provided" )
132143
133144 return tagged_args , tagged_kwargs
134145
135146
136147def turn_to_kwargs_to_avoid_ordering (
137- param_schemas : Sequence [ values . ParamSchema ] ,
148+ op_signature : _schemas . OpSignature ,
138149 inputs : list [Any ],
139150 attributes : dict [str , Any ],
140151) -> dict [str , Any ]:
141152 """Return the inputs and attributes to the order of the function signature."""
142- for idx , param in enumerate (param_schemas ):
153+ for idx , param in enumerate (op_signature . params ):
143154 if param .name not in attributes :
144- if param .is_variadic_input :
155+ is_variadic = isinstance (param , _schemas .Parameter ) and param .variadic
156+ if is_variadic :
145157 attributes [param .name ] = inputs [idx :]
146158 elif inputs :
147159 attributes [param .name ] = inputs .pop (0 )
0 commit comments