Skip to content

Commit c75a37d

Browse files
committed
move subgraph initializers to main graph when inlining If branches
1 parent 12234f8 commit c75a37d

File tree

2 files changed

+112
-37
lines changed

2 files changed

+112
-37
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
3-
43
# NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files.
5-
64
from __future__ import annotations
75

86
__all__ = [
@@ -81,8 +79,6 @@ def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
8179
# The API below works only for non-control-flow ops (ops without any graph-attributes).
8280
# This currently used ONNX's reference implementation. But we could also
8381
# use ORT's implementation if we want to.
84-
85-
8682
def _process_constant_node(node: ir.Node) -> None:
8783
"""Sets const_value of output value of a Constant op node."""
8884
if not _is_onnx_op(node, "Constant"):
@@ -126,7 +122,6 @@ def _process_constant_node(node: ir.Node) -> None:
126122

127123
def basic_constant_propagation(nodes: Iterable[ir.Node]) -> None:
128124
"""Performs basic constant propagation for a sequence of nodes.
129-
130125
Just marks the output values of Constant op nodes with their const_value.
131126
"""
132127
for node in nodes:
@@ -210,12 +205,10 @@ def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
210205

211206
# The "partial evaluators" below are non-standard evaluators. They are used to perform
212207
# partial evaluation and/or static program analysis (abstract interpretation).
213-
214208
# A partial-evaluator function takes a node, a RewriterContext, OptimizerState and returns
215209
# a Replacement for the node or None (if no replacement is needed). It may also return just
216210
# the ir.Value or ir.Values to replace the output values of the node, when the new nodes
217211
# can be inferred from the RewriterContext used to build the new nodes.
218-
219212
RewriterContext = _tape.Builder
220213
ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None]
221214
PartialEvaluatorFunction = Callable[[ir.Node, RewriterContext, OptimizerState], ReturnValue]
@@ -471,7 +464,6 @@ def _propagate_shape_value(node: ir.Node, op, state: OptimizerState) -> ReturnVa
471464
@register("Reshape")
472465
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
473466
"""Replace a Reshape node by Identity when applicable.
474-
475467
Also propagate symbolic shape values.
476468
"""
477469
input = _get_input(node, 0)
@@ -562,6 +554,27 @@ def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
562554
return op.Constant(value_int=size)
563555

564556

557+
def _move_initializers_to_graph(src: ir.Graph, dst: ir.Graph) -> None:
558+
"""Move all initializers from src graph to dst graph, ensuring name uniqueness.
559+
When an If branch is inlined into the main graph, the branch subgraph may
560+
hold initializers (e.g. a constant axes tensor for a Squeeze node) that were
561+
folded in a prior pass. Those initializers must be migrated to the main graph
562+
so that the inlined nodes can still reference them; failing to do so leaves the
563+
references dangling and produces an invalid model.
564+
"""
565+
counter: dict[str, int] = {}
566+
for name in list(src.initializers):
567+
initializer = src.initializers.pop(name)
568+
# Ensure name uniqueness in the destination graph
569+
new_name = name
570+
while new_name in dst.initializers:
571+
counter[name] = counter.get(name, 0) + 1
572+
new_name = f"{name}_{counter[name]}"
573+
if new_name != name:
574+
initializer.name = new_name
575+
dst.register_initializer(initializer)
576+
577+
565578
@register("If")
566579
def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
567580
cond_input = _get_input(node, 0)
@@ -586,7 +599,6 @@ def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
586599
if actual is not None
587600
}
588601
# TODO: Extend renaming to intermediate values.
589-
590602
def rename(name):
591603
return renamings.get(name, name)
592604

@@ -599,6 +611,15 @@ def rename(name):
599611
# Avoid name collision.
600612
sub_node.name = f"{node.name}_{sub_node.name}"
601613

614+
# Move initializers from the subgraph to the main graph to avoid losing them.
615+
# When the If branch was processed in a prior constant-folding pass, any
616+
# constants inside the branch (e.g. the 'axes' tensor for a Squeeze node)
617+
# may have been folded into subgraph initializers. Without this step those
618+
# initializers would be orphaned once the branch nodes are inlined here.
619+
main_graph = node.graph
620+
if main_graph is not None:
621+
_move_initializers_to_graph(graph, main_graph)
622+
602623
# TODO: we should handle initializers as well!
603624
return Replacement(formal_outs, graph_nodes)
604625
return None
@@ -787,7 +808,6 @@ def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValu
787808
@register("SplitToSequence")
788809
def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
789810
"""Rewriting pattern.
790-
791811
From
792812
793813
splits = onnx::SplitToSequence(input, split, axis=axis)
@@ -965,14 +985,13 @@ def _record_contributing_values(original_node: ir.Node, replacement: Replacement
965985

966986
class FoldConstantsPass(ir.passes.InPlacePass):
967987
"""A pass that folds constant expressions in the model.
968-
969988
Attributes:
970989
shape_inference: Whether to perform shape inference.
971990
input_size_limit: Maximum size of input tensors to fold.
972991
output_size_limit: Maximum size of output tensors to fold.
973992
should_fold: An optional function that takes a node and returns True if
974993
the node should be considered for folding.
975-
The function should return True/False value to indicate if this particular
994+
"The function should return a True/False value to indicate if this particular
976995
node should be folded, or None to use the default folding rules.
977996
"""
978997

@@ -1201,7 +1220,6 @@ def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None:
12011220
node.domain,
12021221
node.op_type,
12031222
)
1204-
12051223
return None
12061224

12071225
if _is_non_deterministic_op(node):
@@ -1240,8 +1258,7 @@ def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None:
12401258
for op_type in DEFAULT_CONSTANT_FOLD_BLACKLIST:
12411259
if _is_onnx_op(node, op_type):
12421260
logger.info(
1243-
"Skipping constant folding for node %r because "
1244-
"%s is preserved by default",
1261+
"Skipping constant folding for node %r because %s is preserved by default",
12451262
node.name,
12461263
op_type,
12471264
)
@@ -1464,12 +1481,11 @@ def fold_constants(
14641481
14651482
Returns:
14661483
An instance of `FoldConstantsResult`.
1467-
14681484
"""
14691485
folder_pass = FoldConstantsPass(
14701486
shape_inference=onnx_shape_inference,
14711487
input_size_limit=input_size_limit,
14721488
output_size_limit=output_size_limit,
14731489
should_fold=should_fold,
14741490
)
1475-
return folder_pass(model) # type: ignore[return-value]
1491+
return folder_pass(model)

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,76 @@
1414

1515

1616
class FoldConstantsTest(unittest.TestCase):
17+
def test_fold_if_cond_with_subgraph_initializer(self):
18+
model = ir.from_onnx_text("""
19+
<ir_version: 7, opset_import: [ "" : 17]>
20+
agraph (float[16, 16] x, bool cond) => (float[16, 16] z) {
21+
two = Constant <value_float=2.0> ()
22+
three = Constant <value_float=3.0> ()
23+
z = If (cond) <
24+
then_branch = then_graph () => (then_z) {
25+
temp = Add (two, three)
26+
then_z = Mul (temp, x)
27+
},
28+
else_branch = else_graph () => (else_z) {
29+
else_z = Identity (x)
30+
}
31+
>
32+
}
33+
""")
34+
35+
# Pass 1: fold Add(2.0, 3.0) into a subgraph initializer called 'temp'.
36+
# The If condition is still non-constant so the branch is NOT inlined yet.
37+
_constant_folding.fold_constants(model)
38+
optimizer.remove_unused_nodes(model)
39+
if_node = next(n for n in model.graph if n.op_type == "If")
40+
then_branch = if_node.attributes["then_branch"].as_graph()
41+
self.assertIn("temp", then_branch.initializers)
42+
self.assertNotIn("temp", model.graph.initializers)
43+
44+
# Make the condition a known True constant to trigger branch inlining.
45+
const_true = ir.Value(name="const_true")
46+
const_true.const_value = ir.Tensor(np.array(True))
47+
if_node.replace_input_with(0, const_true)
48+
49+
# Pass 2: inline the If branch.
50+
# 'temp' must be migrated from the subgraph to the main graph.
51+
_constant_folding.fold_constants(model)
52+
optimizer.remove_unused_nodes(model)
53+
self.assertIn("temp", model.graph.initializers)
54+
onnx.checker.check_model(ir.serde.serialize_model(model))
55+
56+
def test_fold_if_cond_with_subgraph_initializer_name_collision(self):
57+
"""Subgraph initializer names that clash with main-graph names get a unique suffix."""
58+
model = ir.from_onnx_text("""
59+
<ir_version: 7, opset_import: [ "" : 17]>
60+
agraph (float[1, 4] x, bool cond) => (float[4] z) {
61+
axes_val = Constant <value_ints=[0]> ()
62+
z = If (cond) <
63+
then_branch = then_branch_graph () => (then_z) {
64+
axes_val_inner = Constant <value_ints=[0]> ()
65+
then_z = Squeeze (x, axes_val_inner)
66+
},
67+
else_branch = else_branch_graph () => (else_z) {
68+
else_z = Squeeze (x, axes_val)
69+
}
70+
>
71+
}
72+
""")
73+
74+
_constant_folding.fold_constants(model)
75+
optimizer.remove_unused_nodes(model)
76+
77+
if_node = next(n for n in model.graph if n.op_type == "If")
78+
const_true = ir.Value(name="const_true_collision")
79+
const_true.const_value = ir.Tensor(np.array(True))
80+
if_node.replace_input_with(0, const_true)
81+
82+
# Must not crash or silently overwrite on name collision.
83+
_constant_folding.fold_constants(model)
84+
optimizer.remove_unused_nodes(model)
85+
onnx.checker.check_model(ir.serde.serialize_model(model))
86+
1787
def _fold(
1888
self,
1989
model: ir.Model | str,
@@ -236,9 +306,7 @@ def test_shape_inference(self):
236306
self.assertEqual(len(optimized.graph), 1)
237307
self.assertIn("C", optimized.graph.initializers)
238308

239-
def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(
240-
self,
241-
):
309+
def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(self):
242310
model = """
243311
<
244312
ir_version: 8,
@@ -260,8 +328,7 @@ def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_
260328

261329
# TODO: There is an unrelated limitation that `symbolic_value` is not
262330
# utilized when the value is only referenced by graph output.
263-
# E.g., the following test model will not have this optimization
264-
# applied.
331+
# E.g., the following test model will not have this optimization applied.
265332
#
266333
# <
267334
# ir_version: 8,
@@ -284,9 +351,7 @@ def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_
284351
self.assertEqual(len(optimized.graph[-2].outputs), 4)
285352
self.assertEqual(optimized.graph[-2].op_type, "Split")
286353

287-
def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split(
288-
self,
289-
):
354+
def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split(self):
290355
model = """
291356
<
292357
ir_version: 8,
@@ -309,9 +374,7 @@ def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_sp
309374
self.assertEqual(len(optimized.graph[-2].outputs), 3)
310375
self.assertEqual(optimized.graph[-2].op_type, "Split")
311376

312-
def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_folded_as_split_with_squeeze(
313-
self,
314-
):
377+
def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_folded_as_split_with_squeeze(self):
315378
model = """
316379
<
317380
ir_version: 8,
@@ -334,9 +397,7 @@ def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_
334397
self.assertEqual(optimized.graph[1].op_type, "Split")
335398
self.assertEqual(len([n for n in optimized.graph if n.op_type == "Squeeze"]), 3)
336399

337-
def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0(
338-
self,
339-
):
400+
def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0(self):
340401
model = """
341402
<
342403
ir_version: 8,
@@ -352,9 +413,7 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0(
352413
self.assertEqual(len(optimized.graph), 3)
353414
self.assertEqual(optimized.graph[2].op_type, "Concat")
354415

355-
def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
356-
self,
357-
):
416+
def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(self):
358417
model = """
359418
<
360419
ir_version: 8,
@@ -736,8 +795,8 @@ def test_multi_graph_identity_output_preserves_output_name(self):
736795
self.assertEqual([input.name for input in optimized.graph.inputs], ["x"])
737796

738797
# This should not be constant-foldable as the constant references an
739-
# attribute and thus the shape cannot be resolved. At the same time it
740-
# should not fail due to the attribute value being None in
798+
# attribute and thus the shape cannot be resolved.
799+
# At the same time it should not fail due to the attribute value being None in
741800
# _process_constant_node
742801
def test_attribute_reference(self):
743802
model = """
@@ -798,4 +857,4 @@ def test_initializer_as_graph_output_is_not_removed(self):
798857

799858

800859
if __name__ == "__main__":
801-
unittest.main()
860+
unittest.main()

0 commit comments

Comments
 (0)