Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,21 @@ def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
return op.Constant(value_int=size)


def _move_initializers_to_graph(src: ir.Graph, dst: ir.Graph) -> None:
"""Move all initializers from src graph to dst graph, ensuring name uniqueness."""
counter: dict[str, int] = {}
for name in list(src.initializers):
initializer = src.initializers.pop(name)
# Ensure name uniqueness in the destination graph
new_name = name
while new_name in dst.initializers:
counter[name] = counter.get(name, 0) + 1
new_name = f"{name}_{counter[name]}"
if new_name != name:
initializer.name = new_name
dst.register_initializer(initializer)


@register("If")
def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
cond_input = _get_input(node, 0)
Expand Down Expand Up @@ -598,7 +613,11 @@ def rename(name):
# Avoid name collision.
sub_node.name = f"{node.name}_{sub_node.name}"

# TODO: we should handle initializers as well!
# Move initializers from the subgraph to the main graph to avoid losing them.
main_graph = node.graph
if main_graph is not None:
_move_initializers_to_graph(graph, main_graph)

return Replacement(formal_outs, graph_nodes)
return None

Expand Down
68 changes: 68 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import numpy as np
import onnx
import onnx.helper
import onnx.numpy_helper
import parameterized

import onnxscript.optimizer as optimizer
Expand Down Expand Up @@ -130,6 +132,72 @@ def test_fold_if_cond(self):
self.assertEqual(optimized.graph[0].outputs[0].name, "z")
self.assertEqual(optimized.graph[0].op_type, "Mul")

def _make_if_model_with_subgraph_initializer(self) -> ir.Model:
"""Build a model where the then_branch of an If node has an initializer."""
# Build then_branch: Squeeze(x, axes) where axes=[0] is an initializer
axes_init = onnx.numpy_helper.from_array(np.array([0], dtype=np.int64), name="axes")
squeeze_node = onnx.helper.make_node(
"Squeeze", inputs=["x", "axes"], outputs=["z"], name="Squeeze_0"
)
then_branch = onnx.helper.make_graph(
[squeeze_node],
"then_branch",
[],
[onnx.helper.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [16, 16])],
initializer=[axes_init],
)
# Build else_branch: Identity(x)
identity_node = onnx.helper.make_node(
"Identity", inputs=["x"], outputs=["z"], name="Identity_0"
)
else_branch = onnx.helper.make_graph(
[identity_node],
"else_branch",
[],
[onnx.helper.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [1, 16, 16])],
)
# Build main graph with a constant True condition
one_node = onnx.helper.make_node(
"Constant",
inputs=[],
outputs=["one"],
value=onnx.helper.make_tensor("one", onnx.TensorProto.BOOL, [], [True]),
)
if_node = onnx.helper.make_node(
"If",
inputs=["one"],
outputs=["z"],
then_branch=then_branch,
else_branch=else_branch,
name="If_1",
)
main_graph = onnx.helper.make_graph(
[one_node, if_node],
"main",
[onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1, 16, 16])],
[onnx.helper.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [16, 16])],
)
proto = onnx.helper.make_model(
main_graph, opset_imports=[onnx.helper.make_opsetid("", 17)]
)
proto.ir_version = 8
Comment thread
justinchuby marked this conversation as resolved.
Outdated
return ir.serde.deserialize_model(proto)

def test_fold_if_cond_with_subgraph_initializer(self):
"""Test that initializers in inlined If branches are moved to the main graph."""
model = self._make_if_model_with_subgraph_initializer()
self.assertIn("axes", model.graph[1].attributes["then_branch"].as_graph().initializers)
self.assertNotIn("axes", model.graph.initializers)

optimized = self._fold(model)

# The If node should be inlined; the axes initializer must be in the main graph
self.assertIn("axes", optimized.graph.initializers)
np.testing.assert_array_equal(
optimized.graph.initializers["axes"].const_value.numpy(),
np.array([0], dtype=np.int64),
)

def test_fold_inside_if_branch(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
Expand Down