Skip to content

Commit 5dac5bc

Browse files
authored
Merge branch 'main' into dependabot/github_actions/codecov/codecov-action-6
2 parents 88ab550 + 4291ff2 commit 5dac5bc

File tree

6 files changed

+215
-6
lines changed

6 files changed

+215
-6
lines changed

onnxscript/_internal/builder.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def build_graph(
133133
*,
134134
opset_imports: dict[str, int] | None = None,
135135
name: str = "subgraph",
136+
parent: GraphBuilder | None = None,
136137
) -> ir.Graph:
137138
"""Build an :class:`ir.Graph` suitable for use as a graph-valued attribute.
138139
@@ -165,6 +166,10 @@ def build_graph(
165166
opset_imports: Opset version map for the subgraph (e.g.
166167
``{"": 23}``). Defaults to ``{"": 23}`` when *None*.
167168
name: Name of the resulting :class:`ir.Graph`.
169+
parent: Optional parent :class:`GraphBuilder`. When provided, the
170+
sub-builder's ``_root`` points to the root builder of the parent,
171+
so that :meth:`Parameter._realize` registers initializers in the
172+
root (main) graph rather than the subgraph.
168173
169174
Returns:
170175
An :class:`ir.Graph` whose inputs and outputs are populated and whose
@@ -188,7 +193,9 @@ def build_graph(
188193
for input_name, ts in resolved_inputs:
189194
subgraph.inputs.append(ir.Value(name=input_name, type=ts.type, shape=ts.shape))
190195

191-
sub_builder = GraphBuilder(subgraph)
196+
sub_builder = GraphBuilder(subgraph, parent=parent)
197+
if parent is not None:
198+
sub_builder._scope_stack = list(parent._scope_stack)
192199
trace_outputs = trace_function(sub_builder.op, *subgraph.inputs)
193200
if not isinstance(trace_outputs, Sequence):
194201
trace_outputs = [trace_outputs]
@@ -209,8 +216,10 @@ def build_graph(
209216
class GraphBuilder:
210217
"""Imperative builder for constructing ONNX IR graphs with automatic constant promotion, type casting, and shape inference."""
211218

212-
def __init__(self, graph: ir.Graph) -> None:
219+
def __init__(self, graph: ir.Graph, parent: GraphBuilder | None = None) -> None:
213220
self._graph = graph
221+
self._parent = parent
222+
self._root: GraphBuilder = parent._root if parent is not None else self
214223

215224
# Get the opset version for "" (default domain) from the graph
216225
if "" not in graph.opset_imports:
@@ -238,6 +247,16 @@ def opset(self, domain: str, version: int = 1) -> OpBuilder:
238247
def op(self) -> OpBuilder:
239248
return self._op_builder
240249

250+
@property
251+
def parent(self) -> GraphBuilder | None:
252+
"""The parent builder, or None for a top-level builder."""
253+
return self._parent
254+
255+
@property
256+
def root(self) -> GraphBuilder:
257+
"""The root (top-level) builder in the parent chain."""
258+
return self._root
259+
241260
@property
242261
def graph(self) -> ir.Graph:
243262
return self._graph
@@ -502,6 +521,7 @@ def subgraph(
502521
outputs,
503522
opset_imports=dict(self._graph.opset_imports),
504523
name=name,
524+
parent=self,
505525
)
506526

507527
def call_op(

onnxscript/_internal/builder_test.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,88 @@ def test_build_graph_custom_name(self):
10611061
)
10621062
self.assertEqual(graph.name, "loop_body")
10631063

1064+
def test_build_graph_with_parent(self):
1065+
"""build_graph with parent sets root on the sub-builder."""
1066+
parent_graph = ir.Graph(
1067+
name="main",
1068+
inputs=[],
1069+
outputs=[],
1070+
nodes=[],
1071+
opset_imports={"": 23},
1072+
)
1073+
parent_builder = builder.GraphBuilder(parent_graph)
1074+
1075+
def body(op, x):
1076+
self.assertIs(op.builder.parent, parent_builder)
1077+
self.assertIs(op.builder.root, parent_builder)
1078+
return op.Identity(x)
1079+
1080+
builder.build_graph(
1081+
body,
1082+
inputs=[FLOAT[3]],
1083+
outputs=[FLOAT[3]],
1084+
parent=parent_builder,
1085+
)
1086+
1087+
def test_subgraph_sets_parent_and_root(self):
1088+
"""GraphBuilder.subgraph() sets parent=self on the sub-builder."""
1089+
parent_graph = ir.Graph(
1090+
name="main",
1091+
inputs=[],
1092+
outputs=[],
1093+
nodes=[],
1094+
opset_imports={"": 23},
1095+
)
1096+
parent_builder = builder.GraphBuilder(parent_graph)
1097+
1098+
def body(op, x):
1099+
self.assertIs(op.builder.parent, parent_builder)
1100+
self.assertIs(op.builder.root, parent_builder)
1101+
return op.Identity(x)
1102+
1103+
parent_builder.subgraph(body, inputs=[FLOAT[3]], outputs=[FLOAT[3]])
1104+
1105+
def test_build_graph_inherits_parent_scope_stack(self):
1106+
"""build_graph copies the parent's scope stack so nodes in the subgraph carry scoped names."""
1107+
parent_graph = ir.Graph(
1108+
name="main",
1109+
inputs=[],
1110+
outputs=[],
1111+
nodes=[],
1112+
opset_imports={"": 23},
1113+
)
1114+
parent_builder = builder.GraphBuilder(parent_graph)
1115+
parent_builder.push_module("encoder", "Encoder")
1116+
parent_builder.push_module("layers.0", "TransformerBlock")
1117+
1118+
subgraph = builder.build_graph(
1119+
lambda op, x: op.Relu(x),
1120+
inputs={"x": FLOAT[3, 4]},
1121+
outputs={"y": FLOAT[3, 4]},
1122+
parent=parent_builder,
1123+
)
1124+
1125+
# The single node created inside the subgraph should carry the
1126+
# parent's scope prefix in its name and metadata.
1127+
node = subgraph.node(0)
1128+
self.assertIn("encoder", node.name)
1129+
self.assertIn("layers.0", node.name)
1130+
self.assertIn("encoder", node.metadata_props["namespace"])
1131+
self.assertIn("TransformerBlock", node.metadata_props["namespace"])
1132+
1133+
def test_root_graph_builder_is_its_own_root(self):
1134+
"""A top-level GraphBuilder has root == self."""
1135+
graph = ir.Graph(
1136+
name="main",
1137+
inputs=[],
1138+
outputs=[],
1139+
nodes=[],
1140+
opset_imports={"": 23},
1141+
)
1142+
gb = builder.GraphBuilder(graph)
1143+
self.assertIs(gb.root, gb)
1144+
self.assertIsNone(gb.parent)
1145+
10641146

10651147
class PartitionInputsAttributesTest(unittest.TestCase):
10661148
"""Tests for GraphBuilder._partition_inputs_attributes."""

onnxscript/nn/_module_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,77 @@ def test_realize_qualifies_name(self):
7272
self.assertEqual(value.name, "layer1.bias")
7373
self.assertIn("layer1.bias", graph.initializers)
7474

75+
def test_realize_in_subgraph_registers_in_root(self):
76+
"""Parameter realized inside a subgraph builder is stored in the root graph."""
77+
from onnxscript._internal.builder import GraphBuilder
78+
from onnxscript.onnx_types import FLOAT
79+
80+
root_graph = ir.Graph(
81+
name="main",
82+
inputs=[],
83+
outputs=[],
84+
nodes=[],
85+
opset_imports={"": 23},
86+
)
87+
root_builder = GraphBuilder(root_graph)
88+
89+
p = Parameter([3, 4], name="weight")
90+
91+
def body_fn(op, x):
92+
# Realize param inside a sub-builder context
93+
p._realize(op.builder) # pylint: disable=protected-access
94+
return op.Add(x, x)
95+
96+
_sub_graph = root_builder.subgraph(
97+
body_fn,
98+
inputs=[FLOAT[3, 4]],
99+
outputs=[FLOAT[3, 4]],
100+
)
101+
# Parameter should be in the ROOT graph's initializers, not the subgraph's
102+
self.assertIn("weight", root_graph.initializers)
103+
self.assertIs(root_graph.initializers["weight"], p)
104+
# The subgraph should NOT have the initializer
105+
self.assertNotIn("weight", _sub_graph.initializers)
106+
107+
def test_realize_in_nested_subgraph_registers_in_root(self):
108+
"""Parameter realized in a doubly-nested subgraph goes to the root graph."""
109+
from onnxscript._internal.builder import GraphBuilder, build_graph
110+
from onnxscript.onnx_types import FLOAT
111+
112+
root_graph = ir.Graph(
113+
name="main",
114+
inputs=[],
115+
outputs=[],
116+
nodes=[],
117+
opset_imports={"": 23},
118+
)
119+
root_builder = GraphBuilder(root_graph)
120+
121+
p = Parameter([3], name="bias")
122+
123+
def inner_fn(op, x):
124+
p._realize(op.builder) # pylint: disable=protected-access
125+
return op.Identity(x)
126+
127+
def outer_fn(op, x):
128+
# Build a nested subgraph
129+
build_graph(
130+
inner_fn,
131+
inputs=[FLOAT[3]],
132+
outputs=[FLOAT[3]],
133+
parent=op.builder,
134+
)
135+
return op.Identity(x)
136+
137+
root_builder.subgraph(
138+
outer_fn,
139+
inputs=[FLOAT[3]],
140+
outputs=[FLOAT[3]],
141+
)
142+
# Even through two levels of nesting, param ends up in root
143+
self.assertIn("bias", root_graph.initializers)
144+
self.assertIs(root_graph.initializers["bias"], p)
145+
75146

76147
class ModuleBasicTest(unittest.TestCase):
77148
def test_parameter_auto_registration(self):

onnxscript/nn/_parameter.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ def dtype(self) -> ir.DataType | None: # type: ignore[override]
6060
def _realize(self, builder: _builder.GraphBuilder) -> Parameter:
6161
"""Qualify the name and register as a graph initializer.
6262
63+
Uses the builder's *root* graph builder to qualify the name and
64+
register the initializer. When the builder is a sub-builder (e.g.
65+
for a Scan body), this ensures the parameter is stored in the
66+
main graph — making it visible as an implicit input to the
67+
subgraph rather than incorrectly placed inside it.
68+
6369
Uses direct assignment to ``graph.initializers[...]`` to skip the
6470
const_value check. Idempotent: subsequent calls are no-ops.
6571
"""
@@ -73,8 +79,9 @@ def _realize(self, builder: _builder.GraphBuilder) -> Parameter:
7379
"Ensure the Parameter is attached to a Module attribute or otherwise "
7480
"initialized with a name before realization."
7581
)
76-
self_name = self.name = builder._qualify_initializer_name(self_name) # pylint: disable=protected-access
77-
builder.graph.initializers[self_name] = self
82+
root = builder.root
83+
self_name = self.name = root._qualify_initializer_name(self_name) # pylint: disable=protected-access
84+
root.graph.initializers[self_name] = self
7885
self._realized = True
7986
return self
8087

onnxscript/rewriter/_pattern_ir.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,11 @@ def __init__(
836836

837837
# Determine the output nodes of the pattern. These are a minimal set of nodes
838838
# whose backward-slices cover the entire pattern.
839-
output_nodes: set[NodePattern] = set()
839+
# Use a dict as an ordered set to preserve deterministic insertion order
840+
# from the outputs sequence. Using a plain set would cause non-deterministic
841+
# ordering due to Python's hash randomization, leading to non-deterministic
842+
# pattern matching behavior.
843+
output_nodes: dict[NodePattern, None] = {}
840844
covered: set[NodePattern] = set()
841845
choice_values_returned: set[ValuePattern] = set()
842846
covered_choice_values: set[ValuePattern] = set()
@@ -848,7 +852,7 @@ def __init__(
848852
if isinstance(value_pattern, NodeOutputPattern):
849853
candidate = value_pattern.producer()
850854
if candidate not in covered:
851-
output_nodes.add(candidate)
855+
output_nodes[candidate] = None
852856
_add_backward_slice(candidate, covered, covered_choice_values)
853857
elif isinstance(value_pattern, (OpIdDispatchOr, BacktrackingOr)):
854858
choice_values_returned.add(value_pattern)

onnxscript/rewriter/_pattern_ir_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,31 @@ def node_checker(context, node):
7171
self.assertTrue(hasattr(producer, "_check"))
7272
self.assertIs(producer._check, node_checker)
7373

74+
def test_graph_pattern_output_nodes_have_deterministic_order(self):
75+
"""Test that GraphPattern.output_nodes preserves insertion order from outputs.
76+
77+
Regression test for https://github.com/microsoft/onnxscript/issues/2234.
78+
When output_nodes was built from a set, Python's hash randomization could
79+
cause non-deterministic ordering, leading to non-deterministic pattern
80+
matching behavior for multi-output patterns.
81+
"""
82+
opset_builder = _pattern_ir.OpsetPatternBuilder("")
83+
x = _pattern_ir.ValuePattern("x")
84+
# Create two distinct node patterns via two separate ops
85+
out_a = opset_builder.Relu(x, _outputs=["a"])
86+
out_b = opset_builder.Sigmoid(x, _outputs=["b"])
87+
outputs = [out_a, out_b]
88+
89+
# Build the graph pattern multiple times and check the order is always the same
90+
for _ in range(50):
91+
graph_pattern = _pattern_ir.GraphPattern(inputs=[x], outputs=outputs, nodes=[])
92+
node_op_ids = [n._op_identifier for n in graph_pattern.output_nodes]
93+
self.assertEqual(
94+
node_op_ids,
95+
[("", "Relu", ""), ("", "Sigmoid", "")],
96+
"output_nodes order must match the order of outputs",
97+
)
98+
7499

75100
if __name__ == "__main__":
76101
unittest.main()

0 commit comments

Comments
 (0)