diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index e399b66100..575d206486 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -1,8 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - # NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files. - from __future__ import annotations __all__ = [ @@ -81,8 +79,6 @@ def _is_onnx_op(node: ir.Node, op_type: str) -> bool: # The API below works only for non-control-flow ops (ops without any graph-attributes). # This currently used ONNX's reference implementation. But we could also # use ORT's implementation if we want to. def _process_constant_node(node: ir.Node) -> None: """Sets const_value of output value of a Constant op node.""" if not _is_onnx_op(node, "Constant"): @@ -126,7 +122,6 @@ def _process_constant_node(node: ir.Node) -> None: def basic_constant_propagation(nodes: Iterable[ir.Node]) -> None: """Performs basic constant propagation for a sequence of nodes. Just marks the output values of Constant op nodes with their const_value. """ for node in nodes: @@ -210,12 +205,10 @@ def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None: # The "partial evaluators" below are non-standard evaluators. They are used to perform # partial evaluation and/or static program analysis (abstract interpretation). A partial-evaluator function takes a node, a RewriterContext, OptimizerState and returns # a Replacement for the node or None (if no replacement is needed). It may also return just # the ir.Value or ir.Values to replace the output values of the node, when the new nodes # can be inferred from the RewriterContext used to build the new nodes.
RewriterContext = _tape.Builder ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None] PartialEvaluatorFunction = Callable[[ir.Node, RewriterContext, OptimizerState], ReturnValue] @@ -471,7 +464,6 @@ def _propagate_shape_value(node: ir.Node, op, state: OptimizerState) -> ReturnVa @register("Reshape") def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: """Replace a Reshape node by Identity when applicable. Also propagate symbolic shape values. """ input = _get_input(node, 0) @@ -562,6 +554,27 @@ def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return op.Constant(value_int=size) +def _move_initializers_to_graph(src: ir.Graph, dst: ir.Graph) -> None: + """Move all initializers from src graph to dst graph, ensuring name uniqueness. + + When an If branch is inlined into the main graph, the branch subgraph may + hold initializers (e.g. a constant axes tensor for a Squeeze node) that were + folded in a prior pass. Those initializers must be migrated to the main graph + so that the inlined nodes can still reference them; failing to do so leaves the + references dangling and produces an invalid model. + """ + counter: dict[str, int] = {} + for name in list(src.initializers): + initializer = src.initializers.pop(name) + # Ensure name uniqueness in the destination graph + # NOTE(review): collisions are only checked against dst initializers; a clash + # with a non-initializer value name in dst is not detected — confirm acceptable. + new_name = name + while new_name in dst.initializers: + counter[name] = counter.get(name, 0) + 1 + new_name = f"{name}_{counter[name]}" + if new_name != name: + initializer.name = new_name + dst.register_initializer(initializer) + + @register("If") def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue: cond_input = _get_input(node, 0) @@ -586,7 +599,6 @@ def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue: if actual is not None } # TODO: Extend renaming to intermediate values. def rename(name): return renamings.get(name, name) @@ -599,6 +611,15 @@ def rename(name): # Avoid name collision.
sub_node.name = f"{node.name}_{sub_node.name}" + # Move initializers from the subgraph to the main graph to avoid losing them. + # When the If branch was processed in a prior constant-folding pass, any + # constants inside the branch (e.g. the 'axes' tensor for a Squeeze node) + # may have been folded into subgraph initializers. Without this step those + # initializers would be orphaned once the branch nodes are inlined here. + main_graph = node.graph + if main_graph is not None: + _move_initializers_to_graph(graph, main_graph) + - # TODO: we should handle initializers as well! return Replacement(formal_outs, graph_nodes) return None @@ -787,7 +808,6 @@ def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValu @register("SplitToSequence") def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: """Rewriting pattern. From splits = onnx::SplitToSequence(input, split, axis=axis) @@ -965,14 +985,13 @@ def _record_contributing_values(original_node: ir.Node, replacement: Replacement class FoldConstantsPass(ir.passes.InPlacePass): """A pass that folds constant expressions in the model. Attributes: shape_inference: Whether to perform shape inference. input_size_limit: Maximum size of input tensors to fold. output_size_limit: Maximum size of output tensors to fold. should_fold: An optional function that takes a node and returns True if the node should be considered for folding. - The function should return True/False value to indicate if this particular + The function should return True/False value to indicate if this particular node should be folded, or None to use the default folding rules.
""" @@ -1201,7 +1220,6 @@ def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None: node.domain, node.op_type, ) - return None if _is_non_deterministic_op(node): @@ -1240,8 +1258,7 @@ def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None: for op_type in DEFAULT_CONSTANT_FOLD_BLACKLIST: if _is_onnx_op(node, op_type): logger.info( - "Skipping constant folding for node %r because " - "%s is preserved by default", + "Skipping constant folding for node %r because %s is preserved by default", node.name, op_type, ) @@ -1464,7 +1481,6 @@ def fold_constants( Returns: An instance of `FoldConstantsResult`. - """ folder_pass = FoldConstantsPass( shape_inference=onnx_shape_inference, diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 60e1284066..f308402e70 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -14,6 +14,76 @@ class FoldConstantsTest(unittest.TestCase): + def test_fold_if_cond_with_subgraph_initializer(self): + model = ir.from_onnx_text(""" + <ir_version: 10, opset_import: ["" : 20]> + agraph (float[16, 16] x, bool cond) => (float[16, 16] z) { + two = Constant <value_float = 2.0> () + three = Constant <value_float = 3.0> () + z = If (cond) < + then_branch = then_graph () => (then_z) { + temp = Add (two, three) + then_z = Mul (temp, x) + }, + else_branch = else_graph () => (else_z) { + else_z = Identity (x) + } + > + } + """) + + # Pass 1: fold Add(2.0, 3.0) into a subgraph initializer called 'temp'. + # The If condition is still non-constant so the branch is NOT inlined yet.
+ _constant_folding.fold_constants(model) + optimizer.remove_unused_nodes(model) + if_node = next(n for n in model.graph if n.op_type == "If") + then_branch = if_node.attributes["then_branch"].as_graph() + self.assertIn("temp", then_branch.initializers) + self.assertNotIn("temp", model.graph.initializers) + + # Make the condition a known True constant to trigger branch inlining. + const_true = ir.Value(name="const_true") + const_true.const_value = ir.Tensor(np.array(True)) + if_node.replace_input_with(0, const_true) + + # Pass 2: inline the If branch. + # 'temp' must be migrated from the subgraph to the main graph. + _constant_folding.fold_constants(model) + optimizer.remove_unused_nodes(model) + self.assertIn("temp", model.graph.initializers) + onnx.checker.check_model(ir.serde.serialize_model(model)) + + def test_fold_if_cond_with_subgraph_initializer_name_collision(self): + """Subgraph initializer names that clash with main-graph names get a unique suffix.""" + model = ir.from_onnx_text(""" + <ir_version: 10, opset_import: ["" : 20]> + agraph (float[1, 4] x, bool cond) => (float[4] z) { + axes_val = Constant <value = int64[1] {0}> () + z = If (cond) < + then_branch = then_branch_graph () => (then_z) { + axes_val = Constant <value = int64[1] {0}> () + then_z = Squeeze (x, axes_val) + }, + else_branch = else_branch_graph () => (else_z) { + else_z = Squeeze (x, axes_val) + } + > + } + """) + + _constant_folding.fold_constants(model) + optimizer.remove_unused_nodes(model) + + if_node = next(n for n in model.graph if n.op_type == "If") + const_true = ir.Value(name="const_true_collision") + const_true.const_value = ir.Tensor(np.array(True)) + if_node.replace_input_with(0, const_true) + + # Must not crash or silently overwrite on name collision.
+ _constant_folding.fold_constants(model) + optimizer.remove_unused_nodes(model) + onnx.checker.check_model(ir.serde.serialize_model(model)) + def _fold( self, model: ir.Model | str, @@ -236,9 +306,7 @@ def test_shape_inference(self): self.assertEqual(len(optimized.graph), 1) self.assertIn("C", optimized.graph.initializers) def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split( self, ): model = """ < ir_version: 8, @@ -260,8 +328,7 @@ def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_ # TODO: There is an unrelated limitation that `symbolic_value` is not # utilized when the value is only referenced by graph output. # E.g., the following test model will not have this optimization # applied. # # < # ir_version: 8, @@ -284,9 +351,7 @@ def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_ self.assertEqual(len(optimized.graph[-2].outputs), 4) self.assertEqual(optimized.graph[-2].op_type, "Split") def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split( self, ): model = """ < ir_version: 8, @@ -309,9 +374,7 @@ def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_sp self.assertEqual(len(optimized.graph[-2].outputs), 3) self.assertEqual(optimized.graph[-2].op_type, "Split") def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_folded_as_split_with_squeeze( self, ): model = """ < ir_version: 8, @@ -334,9 +397,7 @@ def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_
self.assertEqual(optimized.graph[1].op_type, "Split") self.assertEqual(len([n for n in optimized.graph if n.op_type == "Squeeze"]), 3) def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( self, ): model = """ < ir_version: 8, @@ -352,9 +413,7 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( self.assertEqual(len(optimized.graph), 3) self.assertEqual(optimized.graph[2].op_type, "Concat") def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( self, ): model = """ < ir_version: 8, @@ -736,8 +795,8 @@ def test_multi_graph_identity_output_preserves_output_name(self): self.assertEqual([input.name for input in optimized.graph.inputs], ["x"]) # This should not be constant-foldable as the constant references an # attribute and thus the shape cannot be resolved. At the same time it # should not fail due to the attribute value being None in _process_constant_node def test_attribute_reference(self): model = """