22# Licensed under the MIT License.
33from __future__ import annotations
44
5+ from typing import TypeVar
6+
57__all__ = [
6- "fold_constants" ,
7- "fold_constants_ir" ,
8- "remove_unused_nodes" ,
9- "optimize" ,
10- "optimize_ir" ,
118 "basic_constant_propagation" ,
9+ "fold_constants_ir" ,
10+ "fold_constants" ,
1211 "inline" ,
12+ "optimize_ir" ,
13+ "optimize" ,
14+ "remove_unused_nodes" ,
1315]
1416
1517import onnx
1618
1719import onnxscript .ir .passes .common .inliner
1820import onnxscript .ir .passes .common .unused_removal
1921import onnxscript .optimizer ._constant_folding as constant_folding
20- import onnxscript .optimizer ._legacy ._optimizer as legacy_optimizer
21- import onnxscript .optimizer ._legacy .constant_folding as legacy_constant_folding
2222from onnxscript import ir
23+ from onnxscript .optimizer ._constant_folding import (
24+ basic_constant_propagation ,
25+ )
26+ from onnxscript .optimizer ._constant_folding import (
27+ fold_constants as fold_constants_ir ,
28+ )
2329from onnxscript .optimizer ._optimizer import optimize_ir
2430
25- basic_constant_propagation = constant_folding . basic_constant_propagation
26- fold_constants_ir = constant_folding . fold_constants
31+ _ModelProtoOrIr = TypeVar ( "_ModelProtoOrIr" , onnx . ModelProto , ir . Model )
32+
2733
def optimize(
    model: _ModelProtoOrIr,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    inline: bool = True,
) -> _ModelProtoOrIr:
    """Run the optimizer on a model given as an ``ir.Model`` or ``onnx.ModelProto``.

    An ``ir.Model`` is optimized in place and the same object is returned.
    An ``onnx.ModelProto`` is deserialized to an ``ir.Model``, optimized, and
    serialized again; the freshly serialized proto is returned, and this
    function does not copy the result back into the proto that was passed in.

    Args:
        model: The model to be optimized.
        num_iterations: Number of times the optimization loop is repeated.
        onnx_shape_inference: Applies node-level shape-inference as part of
            optimization.
        stop_if_no_change: Stop the optimization loop if no change is detected
            in an iteration.
        input_size_limit: Will not apply constant folding to ops with any input
            of size greater than this. Does not apply to special ops like
            Shape() and Size().
        output_size_limit: Will not rewrite any foldable-op into a Constant op
            if the size of the output tensor is greater than this.
        inline: If True, inlines all functions in the model.

    Returns:
        The optimized model, of the same kind as the input: an ``ir.Model``
        for an ``ir.Model`` input, an ``onnx.ModelProto`` for a ``ModelProto``
        input.
    """

    def _apply_passes(target: ir.Model) -> None:
        # Single place where the user-facing options are forwarded to the
        # IR-level optimizer, shared by both input kinds.
        optimize_ir(
            target,
            num_iterations=num_iterations,
            onnx_shape_inference=onnx_shape_inference,
            stop_if_no_change=stop_if_no_change,
            input_size_limit=input_size_limit,
            output_size_limit=output_size_limit,
            inline=inline,
        )

    if isinstance(model, ir.Model):
        # The IR model is mutated in place; hand the same object back.
        # TODO(justinchuby): Maybe make functional
        _apply_passes(model)
        return model

    assert isinstance(model, onnx.ModelProto)
    # Round-trip through the IR: proto -> ir.Model -> optimize -> proto.
    ir_model = ir.serde.deserialize_model(model)
    _apply_passes(ir_model)
    return ir.serde.serialize_model(ir_model)
3689
3790
3891def inline (model : ir .Model ) -> None :
@@ -43,11 +96,20 @@ def inline(model: ir.Model) -> None:
4396
def fold_constants(
    model: ir.Model | onnx.ModelProto, *args, **kwargs
) -> constant_folding.FoldConstantsResult:
    """Fold constants in a model, updating the passed-in model in place.

    An ``ir.Model`` is handed directly to ``constant_folding.fold_constants``.
    An ``onnx.ModelProto`` is deserialized to an ``ir.Model``, folded, and the
    re-serialized result is copied back into the same proto object.

    Args:
        model: The model to fold constants in.
        *args: Extra positional arguments forwarded unchanged to
            ``constant_folding.fold_constants``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``constant_folding.fold_constants``.

    Returns:
        The result object returned by ``constant_folding.fold_constants``.
    """
    if isinstance(model, ir.Model):
        return constant_folding.fold_constants(model, *args, **kwargs)

    assert isinstance(model, onnx.ModelProto)
    ir_model = ir.serde.deserialize_model(model)
    folding_result = constant_folding.fold_constants(ir_model, *args, **kwargs)
    # Write the folded graph back into the caller's proto so the change is
    # visible through the object they passed in.
    folded_proto = ir.serde.serialize_model(ir_model)
    model.Clear()
    model.CopyFrom(folded_proto)
    return folding_result
51113
52114
53115def remove_unused_nodes (model : ir .Model | onnx .ModelProto ) -> None :
0 commit comments