Skip to content

Commit 7be2601

Browse files
Copilotjustinchuby
andauthored
Fix non-deterministic rewriter behavior by using ordered dict instead of set for output_nodes
Replace set[NodePattern] with dict[NodePattern, None] to preserve deterministic insertion order from the outputs sequence. Python sets have non-deterministic iteration order due to hash randomization, which caused the multi-output pattern matching to behave differently across runs. Fixes #2234 Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/50f46d7a-beee-47c4-9369-3c28417380d1 Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent 12d4645 commit 7be2601

2 files changed

Lines changed: 31 additions & 2 deletions

File tree

onnxscript/rewriter/_pattern_ir.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,11 @@ def __init__(
836836

837837
# Determine the output nodes of the pattern. These are a minimal set of nodes
838838
# whose backward-slices cover the entire pattern.
839-
output_nodes: set[NodePattern] = set()
839+
# Use a dict as an ordered set to preserve deterministic insertion order
840+
# from the outputs sequence. Using a plain set would cause non-deterministic
841+
# ordering due to Python's hash randomization, leading to non-deterministic
842+
# pattern matching behavior.
843+
output_nodes: dict[NodePattern, None] = {}
840844
covered: set[NodePattern] = set()
841845
choice_values_returned: set[ValuePattern] = set()
842846
covered_choice_values: set[ValuePattern] = set()
@@ -848,7 +852,7 @@ def __init__(
848852
if isinstance(value_pattern, NodeOutputPattern):
849853
candidate = value_pattern.producer()
850854
if candidate not in covered:
851-
output_nodes.add(candidate)
855+
output_nodes[candidate] = None
852856
_add_backward_slice(candidate, covered, covered_choice_values)
853857
elif isinstance(value_pattern, (OpIdDispatchOr, BacktrackingOr)):
854858
choice_values_returned.add(value_pattern)

onnxscript/rewriter/_pattern_ir_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,31 @@ def node_checker(context, node):
7171
self.assertTrue(hasattr(producer, "_check"))
7272
self.assertIs(producer._check, node_checker)
7373

74+
def test_graph_pattern_output_nodes_have_deterministic_order(self):
75+
"""Test that GraphPattern.output_nodes preserves insertion order from outputs.
76+
77+
Regression test for https://github.com/microsoft/onnxscript/issues/2234.
78+
When output_nodes was built from a set, Python's hash randomization could
79+
cause non-deterministic ordering, leading to non-deterministic pattern
80+
matching behavior for multi-output patterns.
81+
"""
82+
opset_builder = _pattern_ir.OpsetPatternBuilder("")
83+
x = _pattern_ir.ValuePattern("x")
84+
# Create two distinct node patterns via two separate ops
85+
out_a = opset_builder.Relu(x, _outputs=["a"])
86+
out_b = opset_builder.Sigmoid(x, _outputs=["b"])
87+
outputs = [out_a, out_b]
88+
89+
# Build the graph pattern multiple times and check the order is always the same
90+
for _ in range(50):
91+
graph_pattern = _pattern_ir.GraphPattern(inputs=[x], outputs=outputs, nodes=[])
92+
node_op_ids = [n._op_identifier for n in graph_pattern.output_nodes]
93+
self.assertEqual(
94+
node_op_ids,
95+
[("", "Relu", ""), ("", "Sigmoid", "")],
96+
"output_nodes order must match the order of outputs",
97+
)
98+
7499

75100
if __name__ == "__main__":
76101
unittest.main()

0 commit comments

Comments
 (0)