@@ -797,9 +797,7 @@ def merge_dims(dim1, dim2):
797797 return ir .Shape ([merge_dims (dim1 , dim2 ) for dim1 , dim2 in zip (shape1 , shape2 )])
798798
799799
800- class ConstantFolder :
801- opset_imports : dict [str , int ]
802-
800+ class FoldConstantsPass (ir .passes .PassBase ):
803801 def __init__ (
804802 self ,
805803 * ,
@@ -812,11 +810,17 @@ def __init__(
812810 self ._shape_inference = shape_inference
813811 self ._input_size_limit = input_size_limit
814812 self ._output_size_limit = output_size_limit
815- self ._init ()
816-
817- def _init (self ) -> None :
813+ self .opset_imports : dict [str , int ] = {}
818814 self .counts : dict [str , int ] = {}
819815 self .sizes : dict [str , int ] = {}
816+ self .modified : bool = False
817+ self ._state = OptimizerState ()
818+ self ._reset ()
819+
820+ def _reset (self ) -> None :
821+ """Reset internal states for a new run."""
822+ self .counts = {}
823+ self .sizes = {}
820824 self .modified = False
821825 self ._state = OptimizerState ()
822826
@@ -931,6 +935,7 @@ def process_node(self, node: ir.Node):
931935 sym_value .name ,
932936 )
933937 node .replace_input_with (i , sym_value )
938+ self .modified = True
934939 # TODO(rama): consider merging type/other info from both values
935940
936941 # Do incremental shape inference
@@ -1007,6 +1012,8 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function)
10071012 root , node , [node ], replacement .new_nodes , node .outputs , replacement .new_outputs
10081013 )
10091014
1015+ self .modified = True
1016+
10101017 # TODO: what about new opset_imports?
10111018 # TODO: track statistics about replaced nodes and sizes of new constants
10121019
@@ -1045,13 +1052,14 @@ def visit_function(self, function: ir.Function) -> None:
10451052 for node in function :
10461053 self .visit_node (node , function )
10471054
1048- def visit_model (self , model : ir .Model ) -> None :
1049- self ._init ()
1055+ def call (self , model : ir .Model ) -> ir . passes . PassResult :
1056+ self ._reset ()
10501057 self .opset_imports = model .opset_imports
10511058 self .visit_graph (model .graph )
10521059 for function in model .functions .values ():
10531060 # TODO(rama): Should we specialize functions?
10541061 self .visit_function (function )
1062+ return ir .passes .PassResult (model , self .modified )
10551063
10561064
10571065def fold_constants (
@@ -1066,18 +1074,18 @@ def fold_constants(
10661074 Applies constant folding optimization to the model.
10671075 Returns true iff the model was modified.
10681076 """
1069- folder = ConstantFolder (
1077+ folder_pass = FoldConstantsPass (
10701078 external_data_folder = external_data_folder ,
10711079 shape_inference = onnx_shape_inference ,
10721080 input_size_limit = input_size_limit ,
10731081 output_size_limit = output_size_limit ,
10741082 )
1075- folder . visit_model (model )
1076- for op in folder .counts :
1083+ folder_pass (model )
1084+ for op in folder_pass .counts :
10771085 logger .info (
10781086 "Constant-folded '%s' %s times, with %s size." ,
10791087 op ,
1080- folder .counts [op ],
1081- folder .sizes [op ],
1088+ folder_pass .counts [op ],
1089+ folder_pass .sizes [op ],
10821090 )
1083- return folder .modified
1091+ return folder_pass .modified
0 commit comments