@@ -148,18 +148,24 @@ class Replacement:
148148# Currently, we assume that symbolic dimensions are also guaranteed to be non-negative.
149149# TODO: Add support for negative symbolic dimensions.
150150
# A symbolic value tracked by the optimizer for a given ir.Value: either a
# single replacement ir.Value, a list of ir.Value, or a (symbolic) ir.Shape.
SymbolicValue = Union[ir.Value, list[ir.Value], ir.Shape]
151153
152154class OptimizerState :
153155 def __init__ (self ):
154- self ._sym_value_map : dict [ir .Value , Any ] = {}
156+ self ._sym_value_map : dict [ir .Value , SymbolicValue ] = {}
155157 self ._initializer_inputs : list [set [ir .Value ]] = []
156158
157- def get_sym_value (self , value : ir .Value | None ) -> Any :
159+ @property
160+ def symbolic_value_map (self ) -> dict [ir .Value , SymbolicValue ]:
161+ return self ._sym_value_map
162+
163+ def get_sym_value (self , value : ir .Value | None ) -> SymbolicValue | None :
158164 if value is None :
159165 return None
160166 return self ._sym_value_map .get (value )
161167
162- def set_sym_value (self , value : ir .Value , sym_value : Any ) -> None :
168+ def set_sym_value (self , value : ir .Value , sym_value : SymbolicValue ) -> None :
163169 self ._sym_value_map [value ] = sym_value
164170
165171 def push_initializer_inputs (self ) -> None :
@@ -1094,7 +1100,17 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
10941100 for function in model .functions .values ():
10951101 # TODO(rama): Should we specialize functions?
10961102 self .visit_function (function )
1097- return ir .passes .PassResult (model , self .modified )
1103+ return FoldConstantsResult (model , self .modified , self ._state .symbolic_value_map )
1104+
1105+
@dataclasses.dataclass
class FoldConstantsResult(ir.passes.PassResult):
    """Result of the constant-folding pass.

    Extends the base ``PassResult`` (which carries ``model`` and ``modified``)
    with the symbolic-value map accumulated while folding.
    """

    symbolic_value_map: dict[ir.Value, SymbolicValue]

    def __bool__(self) -> bool:
        """Truthiness mirrors ``modified`` for backward compatibility.

        ``fold_constants`` previously returned a plain bool indicating whether
        the model was modified; existing callers that test the result in a
        boolean context keep working.
        """
        return self.modified
10981114
10991115
11001116def fold_constants (
@@ -1104,23 +1120,31 @@ def fold_constants(
11041120 onnx_shape_inference : bool = False ,
11051121 input_size_limit : int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT ,
11061122 output_size_limit : int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT ,
1107- ) -> bool :
1123+ ) -> FoldConstantsResult :
11081124 """
11091125 Applies constant folding optimization to the model.
1110- Returns true iff the model was modified.
1126+
1127+ Args:
1128+ model: The ONNX model to optimize.
1129+ external_data_folder: Path to the folder containing external data
1130+ for the model. Defaults to an empty string.
1131+ onnx_shape_inference: Whether to enable ONNX shape inference during
1132+ constant folding. Defaults to False.
1133+ input_size_limit: The maximum size (in bytes) of input tensors
1134+ that can be considered for constant folding. Defaults to
1135+ `DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT`.
1136+ output_size_limit: The maximum size (in bytes) of output tensors
1137+ that can be stored after constant folding. Defaults to
1138+ `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
1139+
1140+ Returns:
1141+ An instance of `FoldConstantsResult`.
1142+
11111143 """
11121144 folder_pass = FoldConstantsPass (
11131145 external_data_folder = external_data_folder ,
11141146 shape_inference = onnx_shape_inference ,
11151147 input_size_limit = input_size_limit ,
11161148 output_size_limit = output_size_limit ,
11171149 )
1118- folder_pass (model )
1119- for op in folder_pass .counts :
1120- logger .info (
1121- "Constant-folded '%s' %s times, with %s size." ,
1122- op ,
1123- folder_pass .counts [op ],
1124- folder_pass .sizes [op ],
1125- )
1126- return folder_pass .modified
1150+ return folder_pass (model ) # type: ignore[return-value]
0 commit comments