Skip to content

Commit 46e6f69

Browse files
gramalingam and Copilot authored
Implement schema-based input/attribute partitioning in GraphBuilder (#2837)
- Convert onnx.defs.OpSchema to ir.schemas.OpSignature via from_op_schema and delegate to separate_input_attributes_from_arguments - Add allow_extra_args parameter to separate_input_attributes_from_arguments for rejecting unexpected positional arguments (default True for compat) - Builder uses strict mode: allow_extra_kwargs=False, allow_extra_args=False - Refactor _build test helper: accept TypeSpec, optional trace_function, return ir.Graph directly - Add comprehensive tests for input/attribute partitioning --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent bab4f28 commit 46e6f69

File tree

3 files changed

+215
-51
lines changed

3 files changed

+215
-51
lines changed

onnxscript/_internal/builder.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import onnxscript._internal._inference as inference
1919
import onnxscript.optimizer
20-
from onnxscript._internal import _inliner
20+
from onnxscript._internal import _inliner, param_manipulation
2121

2222
# A permissible value for an op input, which can be converted to an ir.Value.
2323
VALUE_LIKE = Union[
@@ -359,9 +359,16 @@ def _partition_inputs_attributes(
359359
inputs: Sequence[ir.Value | ir.TensorProtocol],
360360
kwargs: dict[str, Any],
361361
) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]:
362-
# Not implemented yet
363-
del schema
364-
return inputs, kwargs
362+
if schema is None:
363+
return inputs, kwargs
364+
op_signature = ir.schemas.OpSignature.from_op_schema(schema)
365+
return param_manipulation.separate_input_attributes_from_arguments(
366+
op_signature,
367+
list(inputs),
368+
kwargs,
369+
fill_defaults=False,
370+
allow_extra_args=False,
371+
)
365372

366373
def _cast_inputs(
367374
self,

onnxscript/_internal/builder_test.py

Lines changed: 191 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,35 @@
1212
import onnxscript._internal.builder as builder
1313
import onnxscript.testing
1414
from onnxscript import script
15-
from onnxscript.onnx_types import DOUBLE, FLOAT
15+
from onnxscript.onnx_types import DOUBLE, FLOAT, INT64
1616

1717
_default_opset_version = 23
1818

1919

20+
def _resolve_type_spec(spec: builder.TypeSpec) -> ir.TypeAndShape:
21+
"""Convert a *TypeSpec* to an :class:`ir.TypeAndShape`.
22+
23+
Accepts either an :class:`ir.TypeAndShape` directly, or a
24+
:class:`~onnxscript.onnx_types.TensorType` subclass (e.g. ``FLOAT[1024]``
25+
or ``FLOAT['M', 'N']``).
26+
27+
NOTE: This is a local copy of :func:`builder._resolve_type_spec` so that
28+
tests do not reference a private helper directly.
29+
"""
30+
from onnxscript.onnx_types import TensorType # pylint: disable=import-outside-toplevel
31+
32+
if isinstance(spec, ir.TypeAndShape):
33+
return spec
34+
if isinstance(spec, type) and issubclass(spec, TensorType):
35+
return spec.to_ir_type_and_shape()
36+
raise TypeError(f"Expected ir.TypeAndShape or a TensorType subclass, got {type(spec)!r}.")
37+
38+
2039
def _build(
21-
trace_function,
22-
input_types: Sequence[ir.TypeAndShape],
23-
output_types: Sequence[ir.TypeAndShape],
24-
) -> ir.Model:
40+
input_types: Sequence[builder.TypeSpec],
41+
trace_function=None,
42+
output_types: Sequence[builder.TypeSpec] | None = None,
43+
) -> ir.Graph:
2544
graph = ir.Graph(
2645
name="test_model",
2746
inputs=[],
@@ -30,25 +49,29 @@ def _build(
3049
opset_imports={"": _default_opset_version},
3150
)
3251

33-
onnx_model = ir.Model(graph=graph, ir_version=10)
52+
resolved_inputs = [_resolve_type_spec(t) for t in input_types]
53+
for i, ts in enumerate(resolved_inputs):
54+
graph.inputs.append(ir.Value(name=f"input_{i}", type=ts.type, shape=ts.shape))
3455

35-
for i, input_type in enumerate(input_types):
36-
input_name = f"input_{i}"
37-
graph.inputs.append(ir.Value(name=input_name, type=input_type))
56+
if trace_function is not None:
57+
graph_builder = builder.GraphBuilder(graph)
58+
outputs = trace_function(graph_builder.op, *graph.inputs)
59+
if not isinstance(outputs, Sequence):
60+
outputs = [outputs]
3861

39-
graph_builder = builder.GraphBuilder(graph)
40-
outputs = trace_function(graph_builder.op, *graph.inputs)
41-
if not isinstance(outputs, Sequence):
42-
outputs = [outputs]
43-
if len(outputs) != len(output_types):
44-
raise ValueError(f"Expected {len(output_types)} outputs, but got {len(outputs)}.")
45-
for output, output_type in zip(outputs, output_types):
46-
output.type = output_type.type # TODO: need merge_type method in ir.Value
47-
output.merge_shapes(output_type.shape)
62+
if output_types is not None:
63+
resolved_outputs = [_resolve_type_spec(t) for t in output_types]
64+
if len(outputs) != len(resolved_outputs):
65+
raise ValueError(
66+
f"Expected {len(resolved_outputs)} outputs, but got {len(outputs)}."
67+
)
68+
for output, ts in zip(outputs, resolved_outputs):
69+
output.type = ts.type
70+
output.merge_shapes(ts.shape)
4871

49-
graph.outputs.extend(outputs)
72+
graph.outputs.extend(outputs)
5073

51-
return onnx_model
74+
return graph
5275

5376

5477
def _create_builder_with_inputs() -> tuple[builder.OpBuilder, ir.Value, ir.Value]:
@@ -57,24 +80,7 @@ def _create_builder_with_inputs() -> tuple[builder.OpBuilder, ir.Value, ir.Value
5780
Returns:
5881
A tuple of (op_builder, input_x, input_y).
5982
"""
60-
graph = ir.Graph(
61-
name="test_model",
62-
inputs=[],
63-
outputs=[],
64-
nodes=[],
65-
opset_imports={"": 23},
66-
)
67-
68-
for i in range(2):
69-
input_name = f"input_{i}"
70-
graph.inputs.append(
71-
ir.Value(
72-
name=input_name,
73-
type=ir.TensorType(ir.DataType.FLOAT),
74-
shape=ir.Shape([2, 3, 4]),
75-
)
76-
)
77-
83+
graph = _build(input_types=[FLOAT[2, 3, 4], FLOAT[2, 3, 4]])
7884
graph_builder = builder.GraphBuilder(graph)
7985
x, y = graph.inputs
8086
return graph_builder.op, x, y
@@ -89,12 +95,11 @@ def _add_mul_add(op: builder.OpBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
8995
return z
9096

9197
float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4]))
92-
model = _build(
93-
_add_mul_add,
98+
graph = _build(
9499
input_types=[float_2d, float_2d],
100+
trace_function=_add_mul_add,
95101
output_types=[float_2d],
96102
)
97-
graph = model.graph
98103
# Expect exactly 3 nodes: Add, Mul, Add
99104
op_types = [node.op_type for node in graph]
100105
self.assertEqual(op_types, ["Add", "Mul", "Add"])
@@ -121,12 +126,11 @@ def _add_with_custom_names(
121126
return z
122127

123128
float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4]))
124-
model = _build(
125-
_add_with_custom_names,
129+
graph = _build(
126130
input_types=[float_2d, float_2d],
131+
trace_function=_add_with_custom_names,
127132
output_types=[float_2d],
128133
)
129-
graph = model.graph
130134

131135
# Verify that the nodes have outputs with the specified names
132136
nodes = list(graph)
@@ -207,12 +211,11 @@ def _ops_with_default_names(
207211
return z
208212

209213
float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4]))
210-
model = _build(
211-
_ops_with_default_names,
214+
graph = _build(
212215
input_types=[float_2d, float_2d],
216+
trace_function=_ops_with_default_names,
213217
output_types=[float_2d],
214218
)
215-
graph = model.graph
216219

217220
# Verify the nodes use the new naming strategy
218221
nodes = list(graph)
@@ -1026,5 +1029,146 @@ def test_build_graph_custom_name(self):
10261029
self.assertEqual(graph.name, "loop_body")
10271030

10281031

1032+
class PartitionInputsAttributesTest(unittest.TestCase):
1033+
"""Tests for GraphBuilder._partition_inputs_attributes."""
1034+
1035+
def test_unknown_op_passes_inputs_and_kwargs_through(self):
1036+
"""An unknown op has no schema, so inputs and kwargs pass through unchanged."""
1037+
1038+
def _dummy(op, x, y):
1039+
return op.DummyOp(x, y, alpha=1.0)
1040+
1041+
graph = _build(
1042+
input_types=[FLOAT[3, 4], FLOAT[3, 4]],
1043+
trace_function=_dummy,
1044+
)
1045+
x, y = graph.inputs
1046+
node = graph.node(0)
1047+
self.assertEqual(node.op_type, "DummyOp")
1048+
self.assertEqual(list(node.inputs), [x, y])
1049+
self.assertEqual(node.attributes["alpha"].as_float(), 1.0)
1050+
1051+
def test_op_with_only_inputs(self):
1052+
"""Add has two inputs and no attributes."""
1053+
1054+
def _add(op, x, y):
1055+
return op.Add(x, y)
1056+
1057+
graph = _build(
1058+
input_types=[FLOAT[3, 4], FLOAT[3, 4]],
1059+
trace_function=_add,
1060+
)
1061+
x, y = graph.inputs
1062+
node = graph.node(0)
1063+
self.assertEqual(node.op_type, "Add")
1064+
self.assertEqual(list(node.inputs), [x, y])
1065+
self.assertEqual(len(node.attributes), 0)
1066+
1067+
def test_op_with_inputs_and_attributes_in_kwargs(self):
1068+
"""Gemm has 3 inputs (A, B, C) and attributes (alpha, beta, transA, transB)."""
1069+
1070+
def _gemm(op, a, b, c):
1071+
return op.Gemm(a, b, c, alpha=2.0, transB=1)
1072+
1073+
graph = _build(
1074+
input_types=[FLOAT[3, 4], FLOAT[4, 5], FLOAT[3, 5]],
1075+
trace_function=_gemm,
1076+
)
1077+
a, b, c = graph.inputs
1078+
node = graph.node(0)
1079+
self.assertEqual(node.op_type, "Gemm")
1080+
self.assertEqual(list(node.inputs), [a, b, c])
1081+
self.assertEqual(node.attributes["alpha"].as_float(), 2.0)
1082+
self.assertEqual(node.attributes["transB"].as_int(), 1)
1083+
1084+
def test_op_with_optional_input_omitted(self):
1085+
"""Gemm's third input (C) is optional. Omitting it should work."""
1086+
1087+
def _gemm_no_c(op, a, b):
1088+
return op.Gemm(a, b, alpha=2.0)
1089+
1090+
graph = _build(
1091+
input_types=[FLOAT[3, 4], FLOAT[4, 5]],
1092+
trace_function=_gemm_no_c,
1093+
)
1094+
a, b = graph.inputs
1095+
node = graph.node(0)
1096+
self.assertEqual(node.op_type, "Gemm")
1097+
self.assertEqual(list(node.inputs), [a, b])
1098+
self.assertEqual(node.attributes["alpha"].as_float(), 2.0)
1099+
1100+
def test_does_not_fill_attribute_defaults(self):
1101+
"""Attribute defaults should not be filled in (fill_defaults=False)."""
1102+
1103+
def _gemm_no_attrs(op, a, b):
1104+
return op.Gemm(a, b)
1105+
1106+
graph = _build(
1107+
input_types=[FLOAT[3, 4], FLOAT[4, 5]],
1108+
trace_function=_gemm_no_attrs,
1109+
)
1110+
node = graph.node(0)
1111+
# alpha, beta, transA, transB all have defaults but should NOT appear
1112+
self.assertFalse(node.attributes)
1113+
1114+
def test_variadic_inputs_with_attribute(self):
1115+
"""Concat has variadic inputs and an axis attribute."""
1116+
1117+
def _concat(op, x, y, z):
1118+
return op.Concat(x, y, z, axis=0)
1119+
1120+
graph = _build(
1121+
input_types=[FLOAT[3, 4], FLOAT[3, 4], FLOAT[3, 4]],
1122+
trace_function=_concat,
1123+
)
1124+
x, y, z = graph.inputs
1125+
node = graph.node(0)
1126+
self.assertEqual(node.op_type, "Concat")
1127+
self.assertEqual(list(node.inputs), [x, y, z])
1128+
self.assertEqual(node.attributes["axis"].as_int(), 0)
1129+
1130+
def test_slice_kwargs_are_correctly_ordered_as_inputs(self):
1131+
"""Calling op.Slice with keyword arguments should place them in schema order."""
1132+
1133+
def _slice(op, data, starts, ends, axes, steps):
1134+
# Pass optional inputs as kwargs in non-schema order
1135+
return op.Slice(data, ends=ends, steps=steps, starts=starts, axes=axes)
1136+
1137+
graph = _build(
1138+
input_types=[FLOAT[20, 10], INT64[2], INT64[2], INT64[2], INT64[2]],
1139+
trace_function=_slice,
1140+
)
1141+
data, starts, ends, axes, steps = graph.inputs
1142+
1143+
slice_node = graph.node(0)
1144+
self.assertEqual(slice_node.op_type, "Slice")
1145+
# Schema order: data, starts, ends, axes, steps
1146+
self.assertEqual(list(slice_node.inputs), [data, starts, ends, axes, steps])
1147+
1148+
def test_omitting_required_input_raises(self):
1149+
"""Omitting a required input should raise TypeError."""
1150+
1151+
def _add_missing_input(op, x):
1152+
return op.Add(x)
1153+
1154+
with self.assertRaises(TypeError):
1155+
_build(
1156+
input_types=[FLOAT[3, 4]],
1157+
trace_function=_add_missing_input,
1158+
)
1159+
1160+
def test_extra_inputs_raises(self):
1161+
"""Extra positional inputs beyond the schema should raise TypeError."""
1162+
1163+
def _add_extra_input(op, x, y, z):
1164+
return op.Add(x, y, z)
1165+
1166+
with self.assertRaises(TypeError):
1167+
_build(
1168+
input_types=[FLOAT[3, 4], FLOAT[3, 4], FLOAT[3, 4]],
1169+
trace_function=_add_extra_input,
1170+
)
1171+
1172+
10291173
if __name__ == "__main__":
10301174
unittest.main()

onnxscript/_internal/param_manipulation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def separate_input_attributes_from_arguments(
1616
kwargs,
1717
fill_defaults: bool = True,
1818
allow_extra_kwargs: bool = False,
19+
allow_extra_args: bool = True,
1920
) -> tuple[list[Any], OrderedDict[str, Any]]:
2021
"""Separate Python args and kwargs into ONNX inputs and attributes.
2122
@@ -26,6 +27,9 @@ def separate_input_attributes_from_arguments(
2627
fill_defaults: Whether to fill the default values for attributes.
2728
allow_extra_kwargs: Whether to allow extra keyword arguments.
2829
When set to True, extra/unknown arguments will be ignored.
30+
allow_extra_args: Whether to allow extra positional arguments beyond
31+
what the schema declares (when no variadic parameter exists).
32+
When set to False, a TypeError is raised for extra args.
2933
3034
Returns:
3135
A tuple of two elements:
@@ -34,6 +38,7 @@ def separate_input_attributes_from_arguments(
3438
3539
Raises:
3640
TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
41+
TypeError: When allow_extra_args is False and there are extra positional args.
3742
TypeError: When a required input is not provided.
3843
"""
3944
# args, kwargs and op_signature.params should be all in order
@@ -46,12 +51,14 @@ def separate_input_attributes_from_arguments(
4651

4752
onnx_inputs = []
4853
onnx_attributes = collections.OrderedDict()
54+
has_variadic = False
4955

5056
for i, param in enumerate(op_signature.params):
5157
is_input = param.is_param()
5258
is_variadic = is_input and param.variadic
5359

5460
if is_variadic:
61+
has_variadic = True
5562
# Exhaust all remaining args
5663
onnx_inputs.extend(args[i:])
5764
args = []
@@ -74,6 +81,12 @@ def separate_input_attributes_from_arguments(
7481
elif param.required:
7582
raise TypeError(f"Required input '{param}' was not provided")
7683

84+
if not allow_extra_args and not has_variadic and len(args) > len(op_signature.params):
85+
raise TypeError(
86+
f"Too many positional arguments: expected {len(op_signature.params)}, "
87+
f"got {len(args)}"
88+
)
89+
7790
return onnx_inputs, onnx_attributes
7891

7992

0 commit comments

Comments
 (0)