Skip to content

Commit 0453e99

Browse files
Modify constant-folder to return computed symbolic value map (#2172)
Modify constant-folder to return computed symbolic value map, which may be useful to the caller. (Eg., fusion optimizations can make use of this information.) --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent d2b3758 commit 0453e99

2 files changed

Lines changed: 42 additions & 16 deletions

File tree

onnxscript/optimizer/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def optimize(model: ir.Model, *args, **kwargs) -> ir.Model:
3535
return legacy_optimizer.optimize(model, *args, **kwargs)
3636

3737

38-
def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs) -> bool:
38+
def fold_constants(
39+
model: ir.Model | onnx.ModelProto, *args, **kwargs
40+
) -> constant_folding.FoldConstantsResult | bool:
3941
if isinstance(model, ir.Model):
4042
return constant_folding.fold_constants(model, *args, **kwargs)
4143
else:

onnxscript/optimizer/_constant_folding.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -148,18 +148,24 @@ class Replacement:
148148
# Currently, we assume that symbolic dimensions are also guaranteed to be non-negative.
149149
# TODO: Add support for negative symbolic dimensions.
150150

151+
SymbolicValue = Union[ir.Value, list[ir.Value], ir.Shape]
152+
151153

152154
class OptimizerState:
153155
def __init__(self):
154-
self._sym_value_map: dict[ir.Value, Any] = {}
156+
self._sym_value_map: dict[ir.Value, SymbolicValue] = {}
155157
self._initializer_inputs: list[set[ir.Value]] = []
156158

157-
def get_sym_value(self, value: ir.Value | None) -> Any:
159+
@property
160+
def symbolic_value_map(self) -> dict[ir.Value, SymbolicValue]:
161+
return self._sym_value_map
162+
163+
def get_sym_value(self, value: ir.Value | None) -> SymbolicValue | None:
158164
if value is None:
159165
return None
160166
return self._sym_value_map.get(value)
161167

162-
def set_sym_value(self, value: ir.Value, sym_value: Any) -> None:
168+
def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
163169
self._sym_value_map[value] = sym_value
164170

165171
def push_initializer_inputs(self) -> None:
@@ -1094,7 +1100,17 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
10941100
for function in model.functions.values():
10951101
# TODO(rama): Should we specialize functions?
10961102
self.visit_function(function)
1097-
return ir.passes.PassResult(model, self.modified)
1103+
return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)
1104+
1105+
1106+
@dataclasses.dataclass
1107+
class FoldConstantsResult(ir.passes.PassResult):
1108+
symbolic_value_map: dict[ir.Value, SymbolicValue]
1109+
1110+
# Add conversion to bool for backward compatibility. The previously returned value
1111+
# for the fold_constants method was a boolean indicating whether the model was modified.
1112+
def __bool__(self) -> bool:
1113+
return self.modified
10981114

10991115

11001116
def fold_constants(
@@ -1104,23 +1120,31 @@ def fold_constants(
11041120
onnx_shape_inference: bool = False,
11051121
input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
11061122
output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
1107-
) -> bool:
1123+
) -> FoldConstantsResult:
11081124
"""
11091125
Applies constant folding optimization to the model.
1110-
Returns true iff the model was modified.
1126+
1127+
Args:
1128+
model: The ONNX model to optimize.
1129+
external_data_folder: Path to the folder containing external data
1130+
for the model. Defaults to an empty string.
1131+
onnx_shape_inference: Whether to enable ONNX shape inference during
1132+
constant folding. Defaults to False.
1133+
input_size_limit: The maximum size (in bytes) of input tensors
1134+
that can be considered for constant folding. Defaults to
1135+
`DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT`.
1136+
output_size_limit: The maximum size (in bytes) of output tensors
1137+
that can be stored after constant folding. Defaults to
1138+
`DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
1139+
1140+
Returns:
1141+
An instance of `FoldConstantsResult`.
1142+
11111143
"""
11121144
folder_pass = FoldConstantsPass(
11131145
external_data_folder=external_data_folder,
11141146
shape_inference=onnx_shape_inference,
11151147
input_size_limit=input_size_limit,
11161148
output_size_limit=output_size_limit,
11171149
)
1118-
folder_pass(model)
1119-
for op in folder_pass.counts:
1120-
logger.info(
1121-
"Constant-folded '%s' %s times, with %s size.",
1122-
op,
1123-
folder_pass.counts[op],
1124-
folder_pass.sizes[op],
1125-
)
1126-
return folder_pass.modified
1150+
return folder_pass(model) # type: ignore[return-value]

0 commit comments

Comments
 (0)