Skip to content

Commit 14c6da4

Browse files
committed
Inherit scope stack
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent d220c95 commit 14c6da4

2 files changed

Lines changed: 30 additions & 0 deletions

File tree

onnxscript/_internal/builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ def build_graph(
194194
subgraph.inputs.append(ir.Value(name=input_name, type=ts.type, shape=ts.shape))
195195

196196
sub_builder = GraphBuilder(subgraph, parent=parent)
197+
if parent is not None:
198+
sub_builder._scope_stack = list(parent._scope_stack)
197199
trace_outputs = trace_function(sub_builder.op, *subgraph.inputs)
198200
if not isinstance(trace_outputs, Sequence):
199201
trace_outputs = [trace_outputs]

onnxscript/_internal/builder_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,34 @@ def body(op, x):
11021102

11031103
parent_builder.subgraph(body, inputs=[FLOAT[3]], outputs=[FLOAT[3]])
11041104

1105+
def test_build_graph_inherits_parent_scope_stack(self):
1106+
"""build_graph copies the parent's scope stack so nodes in the subgraph carry scoped names."""
1107+
parent_graph = ir.Graph(
1108+
name="main",
1109+
inputs=[],
1110+
outputs=[],
1111+
nodes=[],
1112+
opset_imports={"": 23},
1113+
)
1114+
parent_builder = builder.GraphBuilder(parent_graph)
1115+
parent_builder.push_module("encoder", "Encoder")
1116+
parent_builder.push_module("layers.0", "TransformerBlock")
1117+
1118+
subgraph = builder.build_graph(
1119+
lambda op, x: op.Relu(x),
1120+
inputs={"x": FLOAT[3, 4]},
1121+
outputs={"y": FLOAT[3, 4]},
1122+
parent=parent_builder,
1123+
)
1124+
1125+
# The single node created inside the subgraph should carry the
1126+
# parent's scope prefix in its name and metadata.
1127+
node = subgraph.node(0)
1128+
self.assertIn("encoder", node.name)
1129+
self.assertIn("layers.0", node.name)
1130+
self.assertIn("encoder", node.metadata_props["namespace"])
1131+
self.assertIn("TransformerBlock", node.metadata_props["namespace"])
1132+
11051133
def test_root_graph_builder_is_its_own_root(self):
11061134
"""A top-level GraphBuilder has root == self."""
11071135
graph = ir.Graph(

0 commit comments

Comments
 (0)