Skip to content

Commit e45724c

Browse files
justinchubyCopilot
andauthored
Remove ParamSchema (#2768)
Remove ParamSchema and replace it with the new OpSignature. ## BC breaking The `param_schemas()` methods are removed from Ops and ONNXFuntions. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 3adee71 commit e45724c

File tree

9 files changed

+208
-289
lines changed

9 files changed

+208
-289
lines changed

onnxscript/_internal/converter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -871,13 +871,13 @@ def _translate_call_expr(
871871
) -> tuple[values.Op, list[ir.Value | None], list[ir.Attr]]:
872872
"""Translates a call-expression."""
873873
callee = self._translate_callee_expr(node.func)
874-
param_schemas = callee.param_schemas()
874+
op_signature = callee.op_signature
875875
# If the callee's schema is available, we use it to determine the inputs and attributes.
876876
# Otherwise, we map named arguments to attributes and positional arguments to inputs.
877-
if param_schemas:
877+
if op_signature:
878878
kwargs = {x.arg: x.value for x in node.keywords}
879879
args, attrs = param_manipulation.separate_input_attributes_from_arguments(
880-
param_schemas, node.args, kwargs, fill_defaults=False
880+
op_signature, node.args, kwargs, fill_defaults=False
881881
)
882882
args = [self._translate_opt_expr(x) for x in args]
883883
attrs = [

onnxscript/_internal/evaluator.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from onnxscript import onnx_opset, tensor
2727
from onnxscript._internal import autocast, param_manipulation, utils, values
28+
from onnxscript.ir import _schemas
2829

2930
UserModeValue: TypeAlias = Union[Optional[np.ndarray], Sequence["UserModeValue"]]
3031

@@ -273,11 +274,11 @@ def eval_function(
273274
args: The positional arguments to the function.
274275
kwargs: The keyword arguments to the function.
275276
"""
276-
param_schemas = function.param_schemas()
277+
op_signature = function.op_signature
277278
# Split happens in the evaluator instead of the OnnxFunction __call__ method
278279
# so that evaluators can control behaviors like whether to fill in default values for attributes.
279-
tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_param_schemas(
280-
param_schemas,
280+
tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_signature(
281+
op_signature,
281282
args,
282283
kwargs,
283284
fill_defaults=False,
@@ -287,16 +288,16 @@ def eval_function(
287288
adapted_args: list[ExtendedModeValue] = []
288289
adapted_kwargs: dict[str, ExtendedModeValue] = {}
289290
has_array = False
290-
for arg, param_schema in tagged_args:
291-
if param_schema.is_input:
291+
for arg, param in tagged_args:
292+
if isinstance(param, _schemas.Parameter):
292293
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
293294
has_array = has_array or has_array_
294295
adapted_args.append(adapted_arg)
295296
else:
296297
adapted_args.append(arg)
297298

298-
for key, (arg, param_schema) in tagged_kwargs.items():
299-
if param_schema.is_input:
299+
for key, (arg, param) in tagged_kwargs.items():
300+
if isinstance(param, _schemas.Parameter):
300301
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
301302
has_array = has_array or has_array_
302303
adapted_kwargs[key] = adapted_arg

onnxscript/_internal/param_manipulation.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from __future__ import annotations
66

77
import collections
8-
from typing import Any, OrderedDict, Sequence
8+
from typing import Any, OrderedDict
99

10-
from onnxscript._internal import values
10+
from onnxscript.ir import _schemas
1111

1212

1313
def separate_input_attributes_from_arguments(
14-
param_schemas: Sequence[values.ParamSchema],
14+
op_signature: _schemas.OpSignature,
1515
args,
1616
kwargs,
1717
fill_defaults: bool = True,
@@ -20,7 +20,7 @@ def separate_input_attributes_from_arguments(
2020
"""Separate Python args and kwargs into ONNX inputs and attributes.
2121
2222
Args:
23-
param_schemas: The parameter schemas of an Op or a OnnxFunction.
23+
op_signature: The operator signature containing parameter information.
2424
args: The Python positional arguments supplied by the caller.
2525
kwargs: The Python keyword arguments supplied by the caller.
2626
fill_defaults: Whether to fill the default values for attributes.
@@ -36,56 +36,61 @@ def separate_input_attributes_from_arguments(
3636
TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
3737
TypeError: When a required input is not provided.
3838
"""
39-
# args, kwargs and param_schemas should be all in order
39+
# args, kwargs and op_signature.params should be all in order
4040
# user may not specify all inputs or attributes
4141

42-
all_param_names = {param.name for param in param_schemas}
42+
all_param_names = {param.name for param in op_signature.params}
4343
extra_kwargs = set(kwargs).difference(all_param_names)
4444
if extra_kwargs and not allow_extra_kwargs:
4545
raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'")
4646

4747
onnx_inputs = []
4848
onnx_attributes = collections.OrderedDict()
4949

50-
for i, param in enumerate(param_schemas):
51-
if param.is_variadic_input:
50+
for i, param in enumerate(op_signature.params):
51+
is_input = isinstance(param, _schemas.Parameter)
52+
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic
53+
54+
if is_variadic:
5255
# Exhaust all remaining args
5356
onnx_inputs.extend(args[i:])
5457
args = []
5558
continue
5659
if i < len(args):
57-
if param.is_input:
60+
if is_input:
5861
onnx_inputs.append(args[i])
5962
else:
6063
onnx_attributes[param.name] = args[i]
6164
elif param.name in kwargs:
62-
if param.is_input:
65+
if is_input:
6366
onnx_inputs.append(kwargs[param.name])
6467
else:
6568
onnx_attributes[param.name] = kwargs[param.name]
66-
elif (
67-
param.is_attribute and param.default is not values._EmptyDefault # pylint: disable=protected-access
68-
):
69+
elif isinstance(param, _schemas.AttributeParameter) and param.has_default():
6970
# User did not provide the attribute
7071
if fill_defaults:
71-
onnx_attributes[param.name] = param.default
72+
# Extract the value from the Attr object
73+
onnx_attributes[param.name] = param.default.value
7274
elif param.required:
7375
raise TypeError(f"Required input '{param}' was not provided")
7476

7577
return onnx_inputs, onnx_attributes
7678

7779

78-
def tag_arguments_with_param_schemas(
79-
param_schemas: Sequence[values.ParamSchema],
80+
def tag_arguments_with_signature(
81+
op_signature: _schemas.OpSignature,
8082
args,
8183
kwargs,
8284
fill_defaults: bool = True,
8385
allow_extra_kwargs: bool = False,
84-
) -> tuple[list[tuple[Any, values.ParamSchema]], dict[str, tuple[Any, values.ParamSchema]]]:
85-
"""Tag Python args and kwargs with matching ONNX ParamSchema.
86+
) -> tuple[
87+
list[tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]],
88+
dict[str, tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]],
89+
]:
90+
"""Tag Python args and kwargs with matching ONNX Parameter/AttributeParameter.
8691
8792
Args:
88-
param_schemas: The parameter schemas of an Op or a OnnxFunction.
93+
op_signature: The operator signature containing parameter information.
8994
args: The Python positional arguments supplied by the caller.
9095
kwargs: The Python keyword arguments supplied by the caller.
9196
fill_defaults: Whether to fill the default values for attributes.
@@ -94,27 +99,29 @@ def tag_arguments_with_param_schemas(
9499
95100
Returns:
96101
A tuple of two elements:
97-
- A list of tuple of Python positional argument and ParamSchema.
102+
- A list of tuple of Python positional argument and Parameter/AttributeParameter.
98103
- An ordered dictionary of Python keyword argument names and tuple of argument
99-
value and ParamSchema.
104+
value and Parameter/AttributeParameter.
100105
101106
Raises:
102107
TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
103108
TypeError: When a required input is not provided.
104109
"""
105-
# args, kwargs and param_schemas should be all in order
110+
# args, kwargs and op_signature.params should be all in order
106111
# user may not specify all inputs or attributes
107112

108-
all_param_names = {param.name for param in param_schemas}
113+
all_param_names = {param.name for param in op_signature.params}
109114
extra_kwargs = set(kwargs).difference(all_param_names)
110115
if extra_kwargs and not allow_extra_kwargs:
111116
raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'")
112117

113-
tagged_args: list[tuple[Any, values.ParamSchema]] = []
114-
tagged_kwargs: dict[str, tuple[Any, values.ParamSchema]] = {}
118+
tagged_args: list[tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]] = []
119+
tagged_kwargs: dict[str, tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]] = {}
120+
121+
for i, param in enumerate(op_signature.params):
122+
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic
115123

116-
for i, param in enumerate(param_schemas):
117-
if param.is_variadic_input:
124+
if is_variadic:
118125
# Exhaust all remaining args
119126
tagged_args.extend((arg, param) for arg in args[i:])
120127
args = []
@@ -123,25 +130,30 @@ def tag_arguments_with_param_schemas(
123130
tagged_args.append((args[i], param))
124131
elif param.name in kwargs:
125132
tagged_kwargs[param.name] = (kwargs[param.name], param)
126-
elif param.default is not values._EmptyDefault: # pylint: disable=protected-access
133+
elif param.has_default():
127134
# User did not provide the input/attribute
128135
if fill_defaults:
129-
tagged_kwargs[param.name] = (param.default, param)
136+
default_value = param.default
137+
# Extract value from Attr object if it's an AttributeParameter
138+
if isinstance(param, _schemas.AttributeParameter):
139+
default_value = param.default.value
140+
tagged_kwargs[param.name] = (default_value, param)
130141
elif param.required:
131142
raise TypeError(f"Required input/attribute '{param}' was not provided")
132143

133144
return tagged_args, tagged_kwargs
134145

135146

136147
def turn_to_kwargs_to_avoid_ordering(
137-
param_schemas: Sequence[values.ParamSchema],
148+
op_signature: _schemas.OpSignature,
138149
inputs: list[Any],
139150
attributes: dict[str, Any],
140151
) -> dict[str, Any]:
141152
"""Return the inputs and attributes to the order of the function signature."""
142-
for idx, param in enumerate(param_schemas):
153+
for idx, param in enumerate(op_signature.params):
143154
if param.name not in attributes:
144-
if param.is_variadic_input:
155+
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic
156+
if is_variadic:
145157
attributes[param.name] = inputs[idx:]
146158
elif inputs:
147159
attributes[param.name] = inputs.pop(0)

0 commit comments

Comments
 (0)