Skip to content

Commit 0ed5f23

Browse files
justinchuby and Copilot
authored
Replace op_schema with op_signature (#2771)
This pull request refactors how operator schemas are handled throughout the autocast, converter, and evaluator modules. The main change is replacing direct usage of `OpSchema` with a new `_schemas.OpSignature` abstraction, leading to more consistent and modular code when dealing with operator signatures, especially for input casting and evaluation. Several related methods are renamed and refactored for clarity and encapsulation. ## Important changes - The `Evaluator` interface now defines `eval_op` on onnx ops. The old `eval` was removed in favor of a more flexible `eval_op`. The exporter's eval will continue to function with a compatible logic in `class Op` - `op_schema` properties from Functions are removed **Operator signature abstraction and autocast refactor:** * Replaced usage of `OpSchema` with `_schemas.OpSignature` in `onnxscript/_internal/autocast.py`, updating all relevant function signatures and internal logic to use the new abstraction. This includes changing how input parameters are filtered and type constraints are accessed. **AST Converter integration:** * Updated the converter (`onnxscript/_internal/converter.py`) to pass `op_signature` instead of `op_schema` to autocast functions, ensuring compatibility with the new signature abstraction. **Evaluator refactor and encapsulation:** * Refactored the evaluator (`onnxscript/_internal/evaluator.py`) to use `_adapt_inputs`, `_adapt_attributes`, and `_adapt_outputs` methods, encapsulating the logic for adapting inputs/outputs and removing unused or redundant methods. Now, operator signatures are consistently adapted from `OpSchema` using `_schemas.OpSignature`. Additionally: * Clean up typing annotation utilities * Now supply IR attrs directly when creating attributes to avoid proto serialization and loss of subgraph references --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent cd185e8 commit 0ed5f23

File tree

16 files changed

+254
-696
lines changed

16 files changed

+254
-696
lines changed

docs/examples/06_plot_model_local_funs.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,3 @@ def l2norm(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]: # noqa: F821
4747

4848
model = l2norm.to_model_proto()
4949
print(onnx.printer.to_text(model))
50-
51-
# %%
52-
# Let's now explicitly specify which functions to include.
53-
# First, generate a model with no model-local functions:
54-
55-
model = l2norm.to_model_proto(functions=[])
56-
print(onnx.printer.to_text(model))
57-
58-
# %%
59-
# Now, generate a model with one model-local function:
60-
61-
model = l2norm.to_model_proto(functions=[sum])
62-
print(onnx.printer.to_text(model))

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
"packaging",
4242
"protobuf",
4343
)
44-
ONNX_IR = "onnx_ir==0.1.13"
44+
ONNX_IR = "onnx_ir==0.1.15"
4545
ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"
4646

4747

onnxscript/_internal/autocast.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77

88
import numpy as np
99
import onnx
10-
import onnx.helper # noqa: TID251
11-
from onnx.defs import OpSchema
1210

1311
from onnxscript import ir, tensor
12+
from onnxscript.ir import _schemas
1413

1514
if TYPE_CHECKING:
1615
from onnxscript._internal import converter
@@ -20,23 +19,15 @@
2019
# python values into ONNX TensorProto, while the runtime converts python values into
2120
# ONNXScript runtime's value-representation (based on Tensor).
2221

23-
24-
# Utilities to convert a python value to TensorProto (for use by the script converter)
25-
26-
27-
def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue):
28-
return ir.serde.serialize_tensor(ir.tensor(pyvalue, name=tensor_name))
29-
30-
3122
_REPEATED_ATTRIBUTE_TYPES = frozenset(
3223
{
33-
onnx.AttributeProto.FLOATS,
34-
onnx.AttributeProto.INTS,
35-
onnx.AttributeProto.STRINGS,
36-
onnx.AttributeProto.TENSORS,
37-
onnx.AttributeProto.GRAPHS,
38-
onnx.AttributeProto.SPARSE_TENSORS,
39-
onnx.AttributeProto.TYPE_PROTOS,
24+
ir.AttributeType.FLOATS,
25+
ir.AttributeType.INTS,
26+
ir.AttributeType.STRINGS,
27+
ir.AttributeType.TENSORS,
28+
ir.AttributeType.GRAPHS,
29+
ir.AttributeType.SPARSE_TENSORS,
30+
ir.AttributeType.TYPE_PROTOS,
4031
}
4132
)
4233

@@ -45,33 +36,28 @@ def pyvalue_to_onnx_attribute(
4536
key: str,
4637
value: Any,
4738
name_generator: Callable[[], str],
48-
attr_type: onnx.AttributeProto.AttributeType | None = None,
49-
) -> onnx.AttributeProto:
39+
attr_type: ir.AttributeType | None = None,
40+
) -> ir.Attr:
5041
"""Helper function to create an ONNX AttributeProto.
5142
52-
This is a refinement of onnx.helper.make_attribute that works with ONNX Script
53-
conventions for allowed types for attribute-values. In particular, it allows
54-
* Empty lists as attribute values, provided the attribute type is specified
43+
* Empty lists can be attribute values, provided the attribute type is specified
5544
and is a list type.
5645
* Scalar-values like 1.0 as well as lists like [1, -1] to be specified
5746
when the attribute type is TensorProto by automatically converting the value
5847
into a 0-D or 1-D tensor respectively.
5948
"""
49+
# TODO(justinchuby): Remove this function and use onnx-ir directly.
6050
if isinstance(value, list) and not value:
6151
# Empty list value:
6252
if attr_type is None:
6353
raise ValueError("Attribute type must be specified for empty list value.")
6454
if attr_type not in _REPEATED_ATTRIBUTE_TYPES:
6555
raise ValueError("Empty list value is only allowed for repeated attribute types.")
66-
return onnx.AttributeProto(name=key, type=attr_type)
67-
elif attr_type == onnx.AttributeProto.TENSOR and not isinstance(value, onnx.TensorProto):
68-
return onnx.AttributeProto(
69-
name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value)
70-
)
56+
return ir.Attr(name=key, type=attr_type, value=[])
57+
elif attr_type == ir.AttributeType.TENSOR and not isinstance(value, onnx.TensorProto):
58+
return ir.AttrTensor(name=key, value=ir.tensor(value, name=name_generator()))
7159
else:
72-
# When the value is a subgraph, ONNX IR will complain that some values are
73-
# not found from the scope.
74-
return onnx.helper.make_attribute(key, value) # noqa: TID251
60+
return ir.convenience.convert_attribute(key, value, attr_type=attr_type)
7561

7662

7763
# Utilities to convert python values into onnxscript tensors.
@@ -126,7 +112,7 @@ def cast_pyvalue_to_os_tensor(pyvalue, dtype=None):
126112
def cast_inputs(
127113
get_type_info: Callable[[Any], Any],
128114
cast: Callable[[Any, Any], Any],
129-
op_schema: OpSchema | None,
115+
op_signature: _schemas.OpSignature | None,
130116
args,
131117
) -> tuple[Any, ...]:
132118
"""Uses schema specification to support a limited form of auto-casting.
@@ -140,12 +126,13 @@ def cast_inputs(
140126
This is used by the converter in a static-mode, as well as by the eager-mode
141127
execution in a dynamic-mode.
142128
"""
143-
if op_schema is None:
129+
if op_signature is None:
144130
# Either an error or a custom op.
145131
# No checks/casts in this case.
146132
return tuple(cast(x, None) for x in args)
147133

148-
expected_inputs = op_schema.inputs
134+
# Filter to get only input parameters (not AttributeParameters)
135+
expected_inputs = op_signature.inputs
149136
# We make two passes. In the first pass, we identify known type-bindings for
150137
# type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}.
151138
# In the second pass, we use these bindings to cast scalar-values to
@@ -156,17 +143,17 @@ def cast_inputs(
156143
for i, x in enumerate(args):
157144
if i < len(expected_inputs):
158145
expected = expected_inputs[i]
159-
elif expected_inputs[-1].option == OpSchema.FormalParameterOption.Variadic:
146+
elif expected_inputs[-1].variadic:
160147
expected = expected_inputs[-1]
161-
if not expected.is_homogeneous:
148+
if not expected.homogeneous:
162149
args_typevars.append((x, None))
163150
continue
164151
else:
165152
raise ValueError(
166153
f"Number of actual parameters {len(args)} "
167154
f"exceeds number of formal parameters {len(expected_inputs)}."
168155
)
169-
typevar = expected.type_str
156+
typevar = expected.type_constraint.name
170157
if "(" not in typevar:
171158
# typevar is an identifier, like "T"
172159
typeinfo = get_type_info(x)
@@ -177,18 +164,18 @@ def cast_inputs(
177164
return tuple(cast_args)
178165

179166

180-
def dynamic_cast_inputs(op_schema: OpSchema, args):
167+
def dynamic_cast_inputs(op_signature: _schemas.OpSignature, args):
181168
"""Used for autocast during eager-mode execution."""
182169

183170
def get_type_info(x):
184171
return x.dtype if isinstance(x, tensor.Tensor) else None
185172

186-
return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_schema, args)
173+
return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_signature, args)
187174

188175

189176
def static_cast_inputs(
190177
converter_: converter.Converter,
191-
op_schema: Optional[OpSchema],
178+
op_signature: Optional[_schemas.OpSignature],
192179
args: Sequence[Optional[ir.Value]],
193180
) -> tuple[str, ...]:
194181
"""Used for autocast during script-translation.
@@ -212,4 +199,4 @@ def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]:
212199
return converter_.emit1([x_cast], "CastLike", [x, y])
213200
return x
214201

215-
return cast_inputs(get_type_info, cast_like, op_schema, args)
202+
return cast_inputs(get_type_info, cast_like, op_signature, args)

0 commit comments

Comments
 (0)