Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions onnxscript/rewriter/_pattern_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,11 @@ def __init__(

# Determine the output nodes of the pattern. These are a minimal set of nodes
# whose backward-slices cover the entire pattern.
output_nodes: set[NodePattern] = set()
# Use a dict as an ordered set to preserve deterministic insertion order
# from the outputs sequence. Using a plain set would cause non-deterministic
# ordering due to Python's hash randomization, leading to non-deterministic
# pattern matching behavior.
output_nodes: dict[NodePattern, None] = {}
covered: set[NodePattern] = set()
choice_values_returned: set[ValuePattern] = set()
covered_choice_values: set[ValuePattern] = set()
Expand All @@ -848,7 +852,7 @@ def __init__(
if isinstance(value_pattern, NodeOutputPattern):
candidate = value_pattern.producer()
if candidate not in covered:
output_nodes.add(candidate)
output_nodes[candidate] = None
_add_backward_slice(candidate, covered, covered_choice_values)
elif isinstance(value_pattern, (OpIdDispatchOr, BacktrackingOr)):
choice_values_returned.add(value_pattern)
Expand Down
25 changes: 25 additions & 0 deletions onnxscript/rewriter/_pattern_ir_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def node_checker(context, node):
self.assertTrue(hasattr(producer, "_check"))
self.assertIs(producer._check, node_checker)

def test_graph_pattern_output_nodes_have_deterministic_order(self):
"""Test that GraphPattern.output_nodes preserves insertion order from outputs.

Regression test for https://github.com/microsoft/onnxscript/issues/2234.
When output_nodes was built from a set, Python's hash randomization could
cause non-deterministic ordering, leading to non-deterministic pattern
matching behavior for multi-output patterns.
"""
opset_builder = _pattern_ir.OpsetPatternBuilder("")
x = _pattern_ir.ValuePattern("x")
# Create two distinct node patterns via two separate ops
out_a = opset_builder.Relu(x, _outputs=["a"])
out_b = opset_builder.Sigmoid(x, _outputs=["b"])
outputs = [out_a, out_b]

# Build the graph pattern multiple times and check the order is always the same
for _ in range(50):
graph_pattern = _pattern_ir.GraphPattern(inputs=[x], outputs=outputs, nodes=[])
node_op_ids = [n._op_identifier for n in graph_pattern.output_nodes]
self.assertEqual(
node_op_ids,
[("", "Relu", ""), ("", "Sigmoid", "")],
"output_nodes order must match the order of outputs",
)


if __name__ == "__main__":
unittest.main()
Loading