Skip to content

Commit 0a3efad

Browse files
author
Justin Chu
committed
Fix OpBuilder to properly extract _domain, _version, _outputs from kwargs
OpBuilder._call_op was inserting _domain, _version into the kwargs dict, but GraphBuilder.call_op expects domain, version, outputs as separate keyword arguments. This caused them to be treated as node attributes, breaking custom domain handling, schema lookup, type inference, shape inference, and output naming. Changes: - OpBuilder._call_op: pop _domain, _version, _outputs from kwargs and pass as separate keyword args to call_op - Remove _prefix from GraphBuilder.call and OpBuilder.call (only call_inline needs it) - Update test to use push_module/pop_module instead of _prefix on call
1 parent 5a9e00b commit 0a3efad

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

onnxscript/_internal/builder.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -721,11 +721,12 @@ def version(self) -> int | None:
721721
return self._version
722722

723723
def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]):
724-
if "_domain" not in kwargs:
725-
kwargs["_domain"] = self._domain
726-
if self._version is not None and "_version" not in kwargs:
727-
kwargs["_version"] = self._version
728-
return self._builder.call_op(op_type, inputs, kwargs)
724+
domain = kwargs.pop("_domain", self._domain)
725+
version = kwargs.pop("_version", self._version)
726+
outputs = kwargs.pop("_outputs", 1)
727+
return self._builder.call_op(
728+
op_type, inputs, kwargs, domain=domain, version=version, outputs=outputs
729+
)
729730

730731
def __getattr__(self, op_type: str) -> Callable:
731732
return lambda *args, **kwargs: self._call_op(op_type, args, kwargs)

onnxscript/_internal/builder_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -925,8 +925,8 @@ def add_mul(X, Y):
925925
self.assertEqual(len(nodes), 1)
926926
self.assertEqual(nodes[0].op_type, "add_mul")
927927

928-
def test_call_with_prefix_option(self):
929-
"""Test that GraphBuilder.call respects the _prefix option for hierarchical naming."""
928+
def test_call_with_push_module_prefix(self):
929+
"""Test that GraphBuilder.call respects push_module for hierarchical naming."""
930930
op, x, y = _create_builder_with_inputs()
931931

932932
@script(default_opset=op)
@@ -935,7 +935,9 @@ def mul_add_relu(X, Y):
935935
tmp = tmp + X
936936
return op.Relu(tmp)
937937

938-
result = op.call(mul_add_relu, x, y, _prefix="layer1")
938+
op.builder.push_module("layer1")
939+
result = op.call(mul_add_relu, x, y)
940+
op.builder.pop_module()
939941

940942
nodes = list(op.builder.graph)
941943
self.assertEqual(len(nodes), 1)

0 commit comments

Comments
 (0)