import onnxscript._internal.builder as builder
import onnxscript.testing
from onnxscript import script
from onnxscript.onnx_types import DOUBLE, FLOAT, INT64

_default_opset_version = 23
1818
1919
def _resolve_type_spec(spec: builder.TypeSpec) -> ir.TypeAndShape:
    """Normalize a *TypeSpec* into an :class:`ir.TypeAndShape`.

    Two spellings are accepted: an :class:`ir.TypeAndShape` instance, which is
    returned unchanged, or a :class:`~onnxscript.onnx_types.TensorType`
    subclass (e.g. ``FLOAT[1024]`` or ``FLOAT['M', 'N']``), which is converted
    via its ``to_ir_type_and_shape()`` method.

    NOTE: This is a local copy of :func:`builder._resolve_type_spec` so that
    tests do not reference a private helper directly.

    Raises:
        TypeError: If *spec* is neither of the accepted forms.
    """
    from onnxscript.onnx_types import TensorType  # pylint: disable=import-outside-toplevel

    if isinstance(spec, ir.TypeAndShape):
        return spec
    is_tensor_type_class = isinstance(spec, type) and issubclass(spec, TensorType)
    if is_tensor_type_class:
        return spec.to_ir_type_and_shape()
    raise TypeError(f"Expected ir.TypeAndShape or a TensorType subclass, got {type(spec)!r}.")
38+
def _build(
    input_types: Sequence[builder.TypeSpec],
    trace_function=None,
    output_types: Sequence[builder.TypeSpec] | None = None,
) -> ir.Graph:
    """Construct a test graph and optionally trace a function into it.

    Args:
        input_types: Type specs for the graph inputs; the inputs are named
            ``input_0``, ``input_1``, ...
        trace_function: Optional callable ``(op, *inputs) -> output(s)`` that
            records nodes into the graph through a ``GraphBuilder``'s ``op``.
        output_types: Optional type specs merged onto the traced outputs.
            Only consulted when *trace_function* is given.

    Returns:
        The constructed ``ir.Graph``.

    Raises:
        ValueError: If the number of traced outputs does not match
            *output_types*.
    """
    graph = ir.Graph(
        name="test_model",
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": _default_opset_version},
    )

    for index, spec in enumerate(_resolve_type_spec(t) for t in input_types):
        graph.inputs.append(
            ir.Value(name=f"input_{index}", type=spec.type, shape=spec.shape)
        )

    if trace_function is not None:
        op = builder.GraphBuilder(graph).op
        traced = trace_function(op, *graph.inputs)
        outputs = list(traced) if isinstance(traced, Sequence) else [traced]

        if output_types is not None:
            expected = [_resolve_type_spec(t) for t in output_types]
            if len(outputs) != len(expected):
                raise ValueError(
                    f"Expected {len(expected)} outputs, but got {len(outputs)}."
                )
            for value, spec in zip(outputs, expected):
                value.type = spec.type
                value.merge_shapes(spec.shape)

        graph.outputs.extend(outputs)

    return graph
5275
5376
def _create_builder_with_inputs() -> tuple[builder.OpBuilder, ir.Value, ir.Value]:
    """Create an op builder over a graph with two FLOAT[2, 3, 4] inputs.

    Returns:
        A tuple of (op_builder, input_x, input_y).
    """
    graph = _build(input_types=[FLOAT[2, 3, 4], FLOAT[2, 3, 4]])
    op = builder.GraphBuilder(graph).op
    first, second = graph.inputs
    return op, first, second
@@ -89,12 +95,11 @@ def _add_mul_add(op: builder.OpBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
8995 return z
9096
9197 float_2d = ir .TypeAndShape (ir .TensorType (ir .DataType .FLOAT ), ir .Shape ([3 , 4 ]))
92- model = _build (
93- _add_mul_add ,
98+ graph = _build (
9499 input_types = [float_2d , float_2d ],
100+ trace_function = _add_mul_add ,
95101 output_types = [float_2d ],
96102 )
97- graph = model .graph
98103 # Expect exactly 3 nodes: Add, Mul, Add
99104 op_types = [node .op_type for node in graph ]
100105 self .assertEqual (op_types , ["Add" , "Mul" , "Add" ])
@@ -121,12 +126,11 @@ def _add_with_custom_names(
121126 return z
122127
123128 float_2d = ir .TypeAndShape (ir .TensorType (ir .DataType .FLOAT ), ir .Shape ([3 , 4 ]))
124- model = _build (
125- _add_with_custom_names ,
129+ graph = _build (
126130 input_types = [float_2d , float_2d ],
131+ trace_function = _add_with_custom_names ,
127132 output_types = [float_2d ],
128133 )
129- graph = model .graph
130134
131135 # Verify that the nodes have outputs with the specified names
132136 nodes = list (graph )
@@ -207,12 +211,11 @@ def _ops_with_default_names(
207211 return z
208212
209213 float_2d = ir .TypeAndShape (ir .TensorType (ir .DataType .FLOAT ), ir .Shape ([3 , 4 ]))
210- model = _build (
211- _ops_with_default_names ,
214+ graph = _build (
212215 input_types = [float_2d , float_2d ],
216+ trace_function = _ops_with_default_names ,
213217 output_types = [float_2d ],
214218 )
215- graph = model .graph
216219
217220 # Verify the nodes use the new naming strategy
218221 nodes = list (graph )
@@ -1026,5 +1029,146 @@ def test_build_graph_custom_name(self):
10261029 self .assertEqual (graph .name , "loop_body" )
10271030
10281031
class PartitionInputsAttributesTest(unittest.TestCase):
    """Tests for GraphBuilder._partition_inputs_attributes."""

    def test_unknown_op_passes_inputs_and_kwargs_through(self):
        """An unknown op has no schema, so inputs and kwargs pass through unchanged."""

        def _trace(op, x, y):
            return op.DummyOp(x, y, alpha=1.0)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[3, 4]],
            trace_function=_trace,
        )
        first, second = graph.inputs
        dummy_node = graph.node(0)
        self.assertEqual(dummy_node.op_type, "DummyOp")
        self.assertEqual(list(dummy_node.inputs), [first, second])
        self.assertEqual(dummy_node.attributes["alpha"].as_float(), 1.0)

    def test_op_with_only_inputs(self):
        """Add has two inputs and no attributes."""

        def _trace(op, x, y):
            return op.Add(x, y)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[3, 4]],
            trace_function=_trace,
        )
        first, second = graph.inputs
        add_node = graph.node(0)
        self.assertEqual(add_node.op_type, "Add")
        self.assertEqual(list(add_node.inputs), [first, second])
        self.assertEqual(len(add_node.attributes), 0)

    def test_op_with_inputs_and_attributes_in_kwargs(self):
        """Gemm has 3 inputs (A, B, C) and attributes (alpha, beta, transA, transB)."""

        def _trace(op, a, b, c):
            return op.Gemm(a, b, c, alpha=2.0, transB=1)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[4, 5], FLOAT[3, 5]],
            trace_function=_trace,
        )
        mat_a, mat_b, mat_c = graph.inputs
        gemm_node = graph.node(0)
        self.assertEqual(gemm_node.op_type, "Gemm")
        self.assertEqual(list(gemm_node.inputs), [mat_a, mat_b, mat_c])
        self.assertEqual(gemm_node.attributes["alpha"].as_float(), 2.0)
        self.assertEqual(gemm_node.attributes["transB"].as_int(), 1)

    def test_op_with_optional_input_omitted(self):
        """Gemm's third input (C) is optional. Omitting it should work."""

        def _trace(op, a, b):
            return op.Gemm(a, b, alpha=2.0)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[4, 5]],
            trace_function=_trace,
        )
        mat_a, mat_b = graph.inputs
        gemm_node = graph.node(0)
        self.assertEqual(gemm_node.op_type, "Gemm")
        self.assertEqual(list(gemm_node.inputs), [mat_a, mat_b])
        self.assertEqual(gemm_node.attributes["alpha"].as_float(), 2.0)

    def test_does_not_fill_attribute_defaults(self):
        """Attribute defaults should not be filled in (fill_defaults=False)."""

        def _trace(op, a, b):
            return op.Gemm(a, b)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[4, 5]],
            trace_function=_trace,
        )
        gemm_node = graph.node(0)
        # alpha, beta, transA, transB all have defaults but should NOT appear
        self.assertFalse(gemm_node.attributes)

    def test_variadic_inputs_with_attribute(self):
        """Concat has variadic inputs and an axis attribute."""

        def _trace(op, x, y, z):
            return op.Concat(x, y, z, axis=0)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[3, 4], FLOAT[3, 4]],
            trace_function=_trace,
        )
        first, second, third = graph.inputs
        concat_node = graph.node(0)
        self.assertEqual(concat_node.op_type, "Concat")
        self.assertEqual(list(concat_node.inputs), [first, second, third])
        self.assertEqual(concat_node.attributes["axis"].as_int(), 0)

    def test_slice_kwargs_are_correctly_ordered_as_inputs(self):
        """Calling op.Slice with keyword arguments should place them in schema order."""

        def _trace(op, data, starts, ends, axes, steps):
            # Pass optional inputs as kwargs in non-schema order
            return op.Slice(data, ends=ends, steps=steps, starts=starts, axes=axes)

        graph = _build(
            input_types=[FLOAT[20, 10], INT64[2], INT64[2], INT64[2], INT64[2]],
            trace_function=_trace,
        )
        data, starts, ends, axes, steps = graph.inputs

        slice_node = graph.node(0)
        self.assertEqual(slice_node.op_type, "Slice")
        # Schema order: data, starts, ends, axes, steps
        self.assertEqual(list(slice_node.inputs), [data, starts, ends, axes, steps])

    def test_omitting_required_input_raises(self):
        """Omitting a required input should raise TypeError."""

        def _trace(op, x):
            return op.Add(x)

        with self.assertRaises(TypeError):
            _build(
                input_types=[FLOAT[3, 4]],
                trace_function=_trace,
            )

    def test_extra_inputs_raises(self):
        """Extra positional inputs beyond the schema should raise TypeError."""

        def _trace(op, x, y, z):
            return op.Add(x, y, z)

        with self.assertRaises(TypeError):
            _build(
                input_types=[FLOAT[3, 4], FLOAT[3, 4], FLOAT[3, 4]],
                trace_function=_trace,
            )
1172+
if __name__ == "__main__":
    # Allow running this test module directly: `python <this file>`.
    unittest.main()
0 commit comments