Skip to content

Commit c580339

Browse files
Copilotjustinchuby
andcommitted
Fix: move subgraph initializers to main graph when inlining If node branches
When constant folding inlines an If node's branch (because the condition is constant), initializers from the subgraph were not moved to the main graph. This caused them to be lost (e.g., the 'axes' initializer for a Squeeze node inside an If branch). Fix: add _move_initializers_to_graph() helper that moves initializers from the subgraph to the main graph with name uniqueness handling, and call it from if_op() before returning the Replacement. Adds test: test_fold_if_cond_with_subgraph_initializer Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent c3afbce commit c580339

File tree

3 files changed

+88
-33
lines changed

3 files changed

+88
-33
lines changed

=0.1.16

Lines changed: 0 additions & 32 deletions
This file was deleted.

onnxscript/optimizer/_constant_folding.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,21 @@ def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
561561
return op.Constant(value_int=size)
562562

563563

564+
def _move_initializers_to_graph(src: ir.Graph, dst: ir.Graph) -> None:
565+
"""Move all initializers from src graph to dst graph, ensuring name uniqueness."""
566+
counter: dict[str, int] = {}
567+
for name in list(src.initializers):
568+
initializer = src.initializers.pop(name)
569+
# Ensure name uniqueness in the destination graph
570+
new_name = name
571+
while new_name in dst.initializers:
572+
counter[name] = counter.get(name, 0) + 1
573+
new_name = f"{name}_{counter[name]}"
574+
if new_name != name:
575+
initializer.name = new_name
576+
dst.register_initializer(initializer)
577+
578+
564579
@register("If")
565580
def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
566581
cond_input = _get_input(node, 0)
@@ -598,7 +613,11 @@ def rename(name):
598613
# Avoid name collision.
599614
sub_node.name = f"{node.name}_{sub_node.name}"
600615

601-
# TODO: we should handle initializers as well!
616+
# Move initializers from the subgraph to the main graph to avoid losing them.
617+
main_graph = node.graph
618+
if main_graph is not None:
619+
_move_initializers_to_graph(graph, main_graph)
620+
602621
return Replacement(formal_outs, graph_nodes)
603622
return None
604623

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import numpy as np
88
import onnx
9+
import onnx.helper
10+
import onnx.numpy_helper
911
import parameterized
1012

1113
import onnxscript.optimizer as optimizer
@@ -130,6 +132,72 @@ def test_fold_if_cond(self):
130132
self.assertEqual(optimized.graph[0].outputs[0].name, "z")
131133
self.assertEqual(optimized.graph[0].op_type, "Mul")
132134

135+
def _make_if_model_with_subgraph_initializer(self) -> ir.Model:
136+
"""Build a model where the then_branch of an If node has an initializer."""
137+
# Build then_branch: Squeeze(x, axes) where axes=[0] is an initializer
138+
axes_init = onnx.numpy_helper.from_array(np.array([0], dtype=np.int64), name="axes")
139+
squeeze_node = onnx.helper.make_node(
140+
"Squeeze", inputs=["x", "axes"], outputs=["z"], name="Squeeze_0"
141+
)
142+
then_branch = onnx.helper.make_graph(
143+
[squeeze_node],
144+
"then_branch",
145+
[],
146+
[onnx.helper.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [16, 16])],
147+
initializer=[axes_init],
148+
)
149+
# Build else_branch: Identity(x)
150+
identity_node = onnx.helper.make_node(
151+
"Identity", inputs=["x"], outputs=["z"], name="Identity_0"
152+
)
153+
else_branch = onnx.helper.make_graph(
154+
[identity_node],
155+
"else_branch",
156+
[],
157+
[onnx.helper.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [1, 16, 16])],
158+
)
159+
# Build main graph with a constant True condition
160+
one_node = onnx.helper.make_node(
161+
"Constant",
162+
inputs=[],
163+
outputs=["one"],
164+
value=onnx.helper.make_tensor("one", onnx.TensorProto.BOOL, [], [True]),
165+
)
166+
if_node = onnx.helper.make_node(
167+
"If",
168+
inputs=["one"],
169+
outputs=["z"],
170+
then_branch=then_branch,
171+
else_branch=else_branch,
172+
name="If_1",
173+
)
174+
main_graph = onnx.helper.make_graph(
175+
[one_node, if_node],
176+
"main",
177+
[onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1, 16, 16])],
178+
[onnx.helper.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [16, 16])],
179+
)
180+
proto = onnx.helper.make_model(
181+
main_graph, opset_imports=[onnx.helper.make_opsetid("", 17)]
182+
)
183+
proto.ir_version = 8
184+
return ir.serde.deserialize_model(proto)
185+
186+
def test_fold_if_cond_with_subgraph_initializer(self):
187+
"""Test that initializers in inlined If branches are moved to the main graph."""
188+
model = self._make_if_model_with_subgraph_initializer()
189+
self.assertIn("axes", model.graph[1].attributes["then_branch"].as_graph().initializers)
190+
self.assertNotIn("axes", model.graph.initializers)
191+
192+
optimized = self._fold(model)
193+
194+
# The If node should be inlined; the axes initializer must be in the main graph
195+
self.assertIn("axes", optimized.graph.initializers)
196+
np.testing.assert_array_equal(
197+
optimized.graph.initializers["axes"].const_value.numpy(),
198+
np.array([0], dtype=np.int64),
199+
)
200+
133201
def test_fold_inside_if_branch(self):
134202
model = """
135203
<ir_version: 7, opset_import: [ "" : 17]>

0 commit comments

Comments
 (0)