Skip to content

Commit 86b37c5

Browse files
gramalingamCopilot
andcommitted
Make opset_imports required in build_graph and build_function
Remove the default of {"":23} — callers must explicitly specify opset_imports. This avoids silent version assumptions. subgraph() is unchanged since it inherits from the parent graph. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: G Ramalingam <grama@microsoft.com>
1 parent 2208869 commit 86b37c5

File tree

3 files changed

+22
-11
lines changed

3 files changed

+22
-11
lines changed

onnxscript/_internal/builder.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def build_graph(
230230
inputs: Sequence[ir.Value | None],
231231
outputs: Sequence[ir.Value],
232232
*,
233-
opset_imports: dict[str, int] | None = None,
233+
opset_imports: dict[str, int],
234234
name: str = "subgraph",
235235
parent: GraphBuilder | None = None,
236236
) -> ir.Graph:
@@ -246,6 +246,7 @@ def build_graph(
246246
lambda op, x, y: op.Add(x, y),
247247
inputs=[make_value("x", FLOAT[3, 4]), make_value("y", FLOAT[3, 4])],
248248
outputs=[make_value("sum", FLOAT[3, 4])],
249+
opset_imports={"": 23},
249250
)
250251
251252
Args:
@@ -263,7 +264,7 @@ def build_graph(
263264
the expected outputs. After tracing, the name and type of each
264265
declared output are applied to the corresponding returned value.
265266
opset_imports: Opset version map for the subgraph (e.g.
266-
``{"": 23}``). Defaults to ``{"": 23}`` when *None*.
267+
``{"": 23}``).
267268
name: Name of the resulting :class:`ir.Graph`.
268269
parent: Optional parent :class:`GraphBuilder`. When provided, the
269270
sub-builder's ``_root`` points to the root builder of the parent,
@@ -276,9 +277,6 @@ def build_graph(
276277
passed directly as a graph-valued attribute (e.g. the ``body`` attribute of
277278
a ``Scan`` or ``Loop`` node).
278279
"""
279-
if opset_imports is None:
280-
opset_imports = {"": 23}
281-
282280
trace_args, graph_inputs = _split_optional_inputs(inputs)
283281

284282
subgraph = ir.Graph(
@@ -325,7 +323,7 @@ def build_function(
325323
domain: str,
326324
name: str,
327325
attributes: Mapping[str, ir.Attr] | Sequence[ir.Attr] | None = None,
328-
opset_imports: dict[str, int] | None = None,
326+
opset_imports: dict[str, int],
329327
) -> ir.Function:
330328
"""Build an :class:`ir.Function` by tracing *trace_function*.
331329
@@ -340,6 +338,7 @@ def build_function(
340338
[make_value("x"), make_value("y")],
341339
domain="com.example",
342340
name="MyAdd",
341+
opset_imports={"": 23},
343342
)
344343
345344
Args:
@@ -358,15 +357,12 @@ def build_function(
358357
attributes: Function-level attributes. Accepts a
359358
:class:`Mapping` from name to :class:`ir.Attr`, a
360359
:class:`Sequence` of :class:`ir.Attr`, or ``None``.
361-
opset_imports: Opset version map. Defaults to ``{"": 23}``.
360+
opset_imports: Opset version map (e.g. ``{"": 23}``).
362361
363362
Returns:
364363
An :class:`ir.Function` with initializers automatically lifted to
365364
``Constant`` nodes.
366365
"""
367-
if opset_imports is None:
368-
opset_imports = {"": 23}
369-
370366
trace_args, graph_inputs = _split_optional_inputs(inputs)
371367

372368
graph = ir.Graph(

onnxscript/_internal/builder_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from onnxscript.onnx_types import DOUBLE, FLOAT, INT64
1616

1717
_default_opset_version = 23
18+
_opset = {"": _default_opset_version}
1819

1920
# Convenience alias for tests — creates an ir.Value from (name, TypeSpec).
2021
_input = builder.make_value
@@ -1238,6 +1239,7 @@ def test_build_graph_custom_name(self):
12381239
inputs=[_input("x", FLOAT[...])],
12391240
outputs=[_input("y", FLOAT[...])],
12401241
name="loop_body",
1242+
opset_imports=_opset,
12411243
)
12421244
self.assertEqual(graph.name, "loop_body")
12431245

@@ -1262,6 +1264,7 @@ def body(op, x):
12621264
inputs=[_input("x", FLOAT[3])],
12631265
outputs=[_input("y", FLOAT[3])],
12641266
parent=parent_builder,
1267+
opset_imports=_opset,
12651268
)
12661269

12671270
def test_subgraph_sets_parent_and_root(self):
@@ -1304,8 +1307,8 @@ def test_build_graph_inherits_parent_scope_stack(self):
13041307
inputs=[_input("x", FLOAT[3, 4])],
13051308
outputs=[_input("y", FLOAT[3, 4])],
13061309
parent=parent_builder,
1310+
opset_imports=_opset,
13071311
)
1308-
13091312
# The single node created inside the subgraph should carry the
13101313
# parent's scope prefix in its name and metadata.
13111314
node = subgraph.node(0)
@@ -1685,6 +1688,7 @@ def test_build_function_basic(self):
16851688
[_input("x", FLOAT[3, 4]), _input("y", FLOAT[3, 4])],
16861689
domain="com.test",
16871690
name="MyAdd",
1691+
opset_imports=_opset,
16881692
)
16891693
self.assertIsInstance(fn, ir.Function)
16901694
self.assertEqual(fn.domain, "com.test")
@@ -1705,6 +1709,7 @@ def body(op, x, y):
17051709
[_input("x"), _input("y")],
17061710
domain="com.test",
17071711
name="AddAndMul",
1712+
opset_imports=_opset,
17081713
)
17091714
self.assertEqual(len(fn.graph.outputs), 2)
17101715

@@ -1719,6 +1724,7 @@ def test_build_function_with_attributes(self):
17191724
ir.Attr("scale", ir.AttributeType.FLOAT, 0.5),
17201725
ir.Attr("mode", ir.AttributeType.STRING, "fast"),
17211726
],
1727+
opset_imports=_opset,
17221728
)
17231729
self.assertIn("scale", fn.attributes)
17241730
self.assertIn("mode", fn.attributes)
@@ -1732,6 +1738,7 @@ def test_build_function_attributes_as_dict(self):
17321738
domain="com.test",
17331739
name="DictAttr",
17341740
attributes={"scale": attr},
1741+
opset_imports=_opset,
17351742
)
17361743
self.assertIn("scale", fn.attributes)
17371744

@@ -1742,6 +1749,7 @@ def test_build_function_lifts_initializers(self):
17421749
[_input("x", FLOAT[3])],
17431750
domain="com.test",
17441751
name="WithLiteral",
1752+
opset_imports=_opset,
17451753
)
17461754
# No initializers in function body
17471755
self.assertEqual(len(fn.graph.initializers), 0)
@@ -1763,6 +1771,7 @@ def body(op, x, y, z):
17631771
[_input("x", FLOAT[3]), None, _input("z", FLOAT[3])],
17641772
domain="com.test",
17651773
name="OptionalInputs",
1774+
opset_imports=_opset,
17661775
)
17671776
# Graph has 3 inputs: x, a placeholder for the absent y, and z
17681777
self.assertEqual(len(fn.graph.inputs), 3)
@@ -1785,6 +1794,7 @@ def body(op, x):
17851794
[_input("x")],
17861795
domain="com.test",
17871796
name="AppendOutputs",
1797+
opset_imports=_opset,
17881798
)
17891799
self.assertEqual(len(fn.graph.outputs), 1)
17901800
self.assertEqual(fn.graph.outputs[0].name, "result")
@@ -1803,6 +1813,7 @@ def body(op, x):
18031813
[_input("x")],
18041814
domain="com.test",
18051815
name="MixedOutputs",
1816+
opset_imports=_opset,
18061817
)
18071818

18081819
def test_build_function_no_outputs_raises(self):
@@ -1818,6 +1829,7 @@ def body(op, x):
18181829
[_input("x")],
18191830
domain="com.test",
18201831
name="NoOutputs",
1832+
opset_imports=_opset,
18211833
)
18221834

18231835
def test_build_function_input_with_producer_raises(self):
@@ -1832,6 +1844,7 @@ def test_build_function_input_with_producer_raises(self):
18321844
[used_value],
18331845
domain="com.test",
18341846
name="BadInput",
1847+
opset_imports=_opset,
18351848
)
18361849

18371850
def test_build_function_custom_opset(self):
@@ -1852,6 +1865,7 @@ def test_build_function_no_parent_isolation(self):
18521865
[_input("x", FLOAT[3])],
18531866
domain="com.test",
18541867
name="Isolated",
1868+
opset_imports=_opset,
18551869
)
18561870
# After lifting, there should be a Constant node and no initializers
18571871
self.assertEqual(len(fn.graph.initializers), 0)

onnxscript/nn/_module_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def outer_fn(op, x):
131131
inputs=[make_value("x", FLOAT[3])],
132132
outputs=[make_value("y", FLOAT[3])],
133133
parent=op.builder,
134+
opset_imports={"": 23},
134135
)
135136
return op.Identity(x)
136137

0 commit comments

Comments
 (0)