11# Copyright (c) Microsoft Corporation. All rights reserved.
22# Licensed under the MIT License.
3-
43# NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files.
5-
64from __future__ import annotations
75
86__all__ = [
@@ -81,8 +79,6 @@ def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
8179# The API below works only for non-control-flow ops (ops without any graph-attributes).
8280# This currently used ONNX's reference implementation. But we could also
8381# use ORT's implementation if we want to.
84-
85-
8682def _process_constant_node (node : ir .Node ) -> None :
8783 """Sets const_value of output value of a Constant op node."""
8884 if not _is_onnx_op (node , "Constant" ):
@@ -126,7 +122,6 @@ def _process_constant_node(node: ir.Node) -> None:
126122
127123def basic_constant_propagation (nodes : Iterable [ir .Node ]) -> None :
128124 """Performs basic constant propagation for a sequence of nodes.
129-
130125 Just marks the output values of Constant op nodes with their const_value.
131126 """
132127 for node in nodes :
@@ -210,12 +205,10 @@ def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
210205
211206# The "partial evaluators" below are non-standard evaluators. They are used to perform
212207# partial evaluation and/or static program analysis (abstract interpretation).
213-
214208# A partial-evaluator function takes a node, a RewriterContext, OptimizerState and returns
215209# a Replacement for the node or None (if no replacement is needed). It may also return just
216210# the ir.Value or ir.Values to replace the output values of the node, when the new nodes
217211# can be inferred from the RewriterContext used to build the new nodes.
218-
219212RewriterContext = _tape .Builder
220213ReturnValue = Union [Replacement , Sequence [ir .Value ], ir .Value , None ]
221214PartialEvaluatorFunction = Callable [[ir .Node , RewriterContext , OptimizerState ], ReturnValue ]
@@ -471,7 +464,6 @@ def _propagate_shape_value(node: ir.Node, op, state: OptimizerState) -> ReturnVa
471464@register ("Reshape" )
472465def reshape (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
473466 """Replace a Reshape node by Identity when applicable.
474-
475467 Also propagate symbolic shape values.
476468 """
477469 input = _get_input (node , 0 )
@@ -562,6 +554,27 @@ def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
562554 return op .Constant (value_int = size )
563555
564556
557+ def _move_initializers_to_graph (src : ir .Graph , dst : ir .Graph ) -> None :
558+ """Move all initializers from src graph to dst graph, ensuring name uniqueness.
559+ When an If branch is inlined into the main graph, the branch subgraph may
560+ hold initializers (e.g. a constant axes tensor for a Squeeze node) that were
561+ folded in a prior pass. Those initializers must be migrated to the main graph
562+ so that the inlined nodes can still reference them; failing to do so leaves the
563+ references dangling and produces an invalid model.
564+ """
565+ counter : dict [str , int ] = {}
566+ for name in list (src .initializers ):
567+ initializer = src .initializers .pop (name )
568+ # Ensure name uniqueness in the destination graph
569+ new_name = name
570+ while new_name in dst .initializers :
571+ counter [name ] = counter .get (name , 0 ) + 1
572+ new_name = f"{ name } _{ counter [name ]} "
573+ if new_name != name :
574+ initializer .name = new_name
575+ dst .register_initializer (initializer )
576+
577+
565578@register ("If" )
566579def if_op (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
567580 cond_input = _get_input (node , 0 )
@@ -586,7 +599,6 @@ def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
586599 if actual is not None
587600 }
588601 # TODO: Extend renaming to intermediate values.
589-
590602 def rename (name ):
591603 return renamings .get (name , name )
592604
@@ -599,6 +611,15 @@ def rename(name):
599611 # Avoid name collision.
600612 sub_node .name = f"{ node .name } _{ sub_node .name } "
601613
614+ # Move initializers from the subgraph to the main graph to avoid losing them.
615+ # When the If branch was processed in a prior constant-folding pass, any
616+ # constants inside the branch (e.g. the 'axes' tensor for a Squeeze node)
617+ # may have been folded into subgraph initializers. Without this step those
618+ # initializers would be orphaned once the branch nodes are inlined here.
619+ main_graph = node .graph
620+ if main_graph is not None :
621+ _move_initializers_to_graph (graph , main_graph )
622+
602623 # TODO: we should handle initializers as well!
603624 return Replacement (formal_outs , graph_nodes )
604625 return None
@@ -787,7 +808,6 @@ def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValu
787808@register ("SplitToSequence" )
788809def split_to_sequence (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
789810 """Rewriting pattern.
790-
791811 From
792812
793813 splits = onnx::SplitToSequence(input, split, axis=axis)
@@ -965,14 +985,13 @@ def _record_contributing_values(original_node: ir.Node, replacement: Replacement
965985
966986class FoldConstantsPass (ir .passes .InPlacePass ):
967987 """A pass that folds constant expressions in the model.
968-
969988 Attributes:
970989 shape_inference: Whether to perform shape inference.
971990 input_size_limit: Maximum size of input tensors to fold.
972991 output_size_limit: Maximum size of output tensors to fold.
973992 should_fold: An optional function that takes a node and returns True if
974993 the node should be considered for folding.
975- The function should return True/False value to indicate if this particular
994+ The function should return True/False value to indicate if this particular
976995 node should be folded, or None to use the default folding rules.
977996 """
978997
@@ -1201,7 +1220,6 @@ def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None:
12011220 node .domain ,
12021221 node .op_type ,
12031222 )
1204-
12051223 return None
12061224
12071225 if _is_non_deterministic_op (node ):
@@ -1240,8 +1258,7 @@ def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None:
12401258 for op_type in DEFAULT_CONSTANT_FOLD_BLACKLIST :
12411259 if _is_onnx_op (node , op_type ):
12421260 logger .info (
1243- "Skipping constant folding for node %r because "
1244- "%s is preserved by default" ,
1261+ "Skipping constant folding for node %r because %s is preserved by default" ,
12451262 node .name ,
12461263 op_type ,
12471264 )
@@ -1464,12 +1481,11 @@ def fold_constants(
14641481
14651482 Returns:
14661483 An instance of `FoldConstantsResult`.
1467-
14681484 """
14691485 folder_pass = FoldConstantsPass (
14701486 shape_inference = onnx_shape_inference ,
14711487 input_size_limit = input_size_limit ,
14721488 output_size_limit = output_size_limit ,
14731489 should_fold = should_fold ,
14741490 )
1475- return folder_pass (model ) # type: ignore[return-value]
1491+ return folder_pass (model )
0 commit comments