
Commit 2c87912

Merge branch 'main' into titaiwang/add_constant_to_initilizer_pass
2 parents 22df674 + 0453e99 commit 2c87912

15 files changed: 102 additions and 1632 deletions

onnxscript/_internal/version_utils.py

Lines changed: 0 additions & 20 deletions
@@ -43,26 +43,6 @@ def transformers_older_than(version: str) -> bool | None:
     )
 
 
-def is_onnxruntime_training() -> bool:
-    """Returns True if the onnxruntime is onnxruntime-training."""
-    try:
-        from onnxruntime import training  # pylint: disable=import-outside-toplevel
-
-        assert training
-    except ImportError:
-        # onnxruntime not training
-        return False
-
-    try:
-        from onnxruntime.capi.onnxruntime_pybind11_state import (  # pylint: disable=import-outside-toplevel
-            OrtValueVector,
-        )
-    except ImportError:
-        return False
-
-    return hasattr(OrtValueVector, "push_back_batch")
-
-
 def onnxruntime_older_than(version: str) -> bool:
     """Returns True if the onnxruntime version is older than the given version."""
     import onnxruntime  # pylint: disable=import-outside-toplevel

onnxscript/optimizer/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,9 @@ def optimize(model: ir.Model, *args, **kwargs) -> ir.Model:
         return legacy_optimizer.optimize(model, *args, **kwargs)
 
 
-def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs) -> bool:
+def fold_constants(
+    model: ir.Model | onnx.ModelProto, *args, **kwargs
+) -> constant_folding.FoldConstantsResult | bool:
     if isinstance(model, ir.Model):
         return constant_folding.fold_constants(model, *args, **kwargs)
     else:

onnxscript/optimizer/_constant_folding.py

Lines changed: 39 additions & 15 deletions
@@ -148,18 +148,24 @@ class Replacement:
 # Currently, we assume that symbolic dimensions are also guaranteed to be non-negative.
 # TODO: Add support for negative symbolic dimensions.
 
+SymbolicValue = Union[ir.Value, list[ir.Value], ir.Shape]
+
 
 class OptimizerState:
     def __init__(self):
-        self._sym_value_map: dict[ir.Value, Any] = {}
+        self._sym_value_map: dict[ir.Value, SymbolicValue] = {}
         self._initializer_inputs: list[set[ir.Value]] = []
 
-    def get_sym_value(self, value: ir.Value | None) -> Any:
+    @property
+    def symbolic_value_map(self) -> dict[ir.Value, SymbolicValue]:
+        return self._sym_value_map
+
+    def get_sym_value(self, value: ir.Value | None) -> SymbolicValue | None:
         if value is None:
             return None
         return self._sym_value_map.get(value)
 
-    def set_sym_value(self, value: ir.Value, sym_value: Any) -> None:
+    def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
         self._sym_value_map[value] = sym_value
 
     def push_initializer_inputs(self) -> None:

@@ -1094,7 +1100,17 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
         for function in model.functions.values():
             # TODO(rama): Should we specialize functions?
             self.visit_function(function)
-        return ir.passes.PassResult(model, self.modified)
+        return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)
+
+
+@dataclasses.dataclass
+class FoldConstantsResult(ir.passes.PassResult):
+    symbolic_value_map: dict[ir.Value, SymbolicValue]
+
+    # Add conversion to bool for backward compatibility. The previously returned value
+    # for the fold_constants method was a boolean indicating whether the model was modified.
+    def __bool__(self) -> bool:
+        return self.modified
 
 
 def fold_constants(

@@ -1104,23 +1120,31 @@ def fold_constants(
     onnx_shape_inference: bool = False,
     input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
     output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
-) -> bool:
+) -> FoldConstantsResult:
     """
     Applies constant folding optimization to the model.
-    Returns true iff the model was modified.
+
+    Args:
+        model: The ONNX model to optimize.
+        external_data_folder: Path to the folder containing external data
+            for the model. Defaults to an empty string.
+        onnx_shape_inference: Whether to enable ONNX shape inference during
+            constant folding. Defaults to False.
+        input_size_limit: The maximum size (in bytes) of input tensors
+            that can be considered for constant folding. Defaults to
+            `DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT`.
+        output_size_limit: The maximum size (in bytes) of output tensors
+            that can be stored after constant folding. Defaults to
+            `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
+
+    Returns:
+        An instance of `FoldConstantsResult`.
+
     """
     folder_pass = FoldConstantsPass(
         external_data_folder=external_data_folder,
         shape_inference=onnx_shape_inference,
         input_size_limit=input_size_limit,
         output_size_limit=output_size_limit,
     )
-    folder_pass(model)
-    for op in folder_pass.counts:
-        logger.info(
-            "Constant-folded '%s' %s times, with %s size.",
-            op,
-            folder_pass.counts[op],
-            folder_pass.sizes[op],
-        )
-    return folder_pass.modified
+    return folder_pass(model)  # type: ignore[return-value]
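
The new return type is meant to stay drop-in compatible with the old boolean contract. A minimal usage sketch (not part of this commit; "model.onnx" is a placeholder path) of calling the public onnxscript.optimizer.fold_constants wrapper on an ir.Model and reading both the backward-compatible flag and the newly exposed symbolic-value map:

    # Sketch only: assumes onnx and onnxscript are installed and the model path exists.
    import onnx
    import onnxscript.optimizer
    from onnxscript import ir

    model = ir.serde.deserialize_model(onnx.load("model.onnx"))  # placeholder path
    result = onnxscript.optimizer.fold_constants(model, onnx_shape_inference=False)

    # Backward compatible: FoldConstantsResult.__bool__ forwards to `modified`.
    if result:
        print("constant folding changed the model")

    # New: symbolic values collected by the pass are exposed on the result.
    for value, sym_value in result.symbolic_value_map.items():
        print(value.name, type(sym_value).__name__)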

onnxscript/optimizer/_inliner.py

Lines changed: 1 addition & 0 deletions
@@ -190,6 +190,7 @@ def id_abbreviation(id: ir.OperatorIdentifier) -> str:
 
 class InlinePass(ir.passes.InPlacePass):
     def __init__(self) -> None:
+        super().__init__()
         self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
         self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
         self._opset_imports: dict[str, int] = {}

onnxscript/optimizer/_legacy/_optimizer.py

Lines changed: 1 addition & 2 deletions
@@ -15,7 +15,6 @@
     inline_simple_functions,
 )
 from onnxscript.optimizer._legacy.constant_folding import fold_constants
-from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES
 
 logger = logging.getLogger(__name__)
 

@@ -75,7 +74,7 @@ def optimize(
         onnxscript.optimizer.remove_unused_functions(model)
         inline_functions_with_unused_outputs(model)
         # NOTE: This is general rewrite rules
-        model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
+        model = rewriter.rewrite(model)
         if stop_if_no_change and not modified:
             logger.debug("Stopping after %d iterations.", _)
             break

onnxscript/optimizer/_optimizer.py

Lines changed: 1 addition & 19 deletions
@@ -6,29 +6,11 @@
 
 import onnxscript.ir.passes.common.constant_manipulation
 import onnxscript.ir.passes.common.unused_removal
-import onnxscript.optimizer
 from onnxscript import ir, rewriter
 from onnxscript.optimizer import _constant_folding, _inliner
-from onnxscript.rewriter import (
-    broadcast_to_matmul,
-    cast_constant_of_shape,
-    collapse_slices,
-    gemm_to_matmul_add,
-    llama_rule_sets,
-    no_op,
-)
 
 logger = logging.getLogger(__name__)
 
-_DEFAULT_REWRITE_RULES: tuple[rewriter.pattern.RewriteRule, ...] = (
-    *no_op.rules.rules,  # TODO: merge this rule into constant folding?
-    *broadcast_to_matmul.rules.rules,
-    gemm_to_matmul_add.rule,  # type: ignore[has-type]
-    *cast_constant_of_shape.rules.rules,
-    *collapse_slices.rules.rules,
-    *llama_rule_sets.llama_p0_rule_set().rules,
-)
-
 
 def optimize_ir(
     model: ir.Model,

@@ -62,7 +44,7 @@ def optimize_ir(
            input_size_limit=input_size_limit,
            output_size_limit=output_size_limit,
        ),
-        rewriter.RewritePass(_DEFAULT_REWRITE_RULES),
+        rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
        onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
        onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass(),
        onnxscript.ir.passes.common.unused_removal.RemoveUnusedOpsetsPass(),
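
A rough sketch (with "model.onnx" as a placeholder path) of using the relocated default rule set directly, mirroring the pass that optimize_ir now builds:

    # Sketch only: the default rules now live in onnxscript.rewriter, not in _optimizer.
    import onnx
    from onnxscript import ir, rewriter

    model = ir.serde.deserialize_model(onnx.load("model.onnx"))  # placeholder path
    rewrite_pass = rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES)
    result = rewrite_pass(model)  # in-place pass; returns an ir.passes.PassResult
    print("rewrite modified the model:", result.modified)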

onnxscript/rewriter/__init__.py

Lines changed: 52 additions & 17 deletions
@@ -5,46 +5,81 @@
 from typing import Sequence, TypeVar, Union
 
 __all__ = [
-    # Modules
     "pattern",
-    # Functions
     "rewrite",
+    "RewritePass",
 ]
 
 import onnx
 
 from onnxscript import ir
 from onnxscript.ir.passes.common import unused_removal
-from onnxscript.rewriter import pattern
+from onnxscript.rewriter import (
+    broadcast_to_matmul,
+    cast_constant_of_shape,
+    collapse_slices,
+    gemm_to_matmul_add,
+    llama_rule_sets,
+    no_op,
+    pattern,
+)
 
-PatternRewriteRule = pattern.RewriteRule
-
-ModelProtoOrIr = TypeVar("ModelProtoOrIr", onnx.ModelProto, ir.Model)
+_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
+_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
+    *no_op.rules.rules,  # TODO: merge this rule into constant folding?
+    *broadcast_to_matmul.rules.rules,
+    gemm_to_matmul_add.rule,  # type: ignore[has-type]
+    *cast_constant_of_shape.rules.rules,
+    *collapse_slices.rules.rules,
+    *llama_rule_sets.llama_p0_rule_set().rules,
+)
 
 
 class RewritePass(ir.passes.InPlacePass):
     def __init__(
         self,
-        pattern_rewrite_rules: Sequence[PatternRewriteRule] | pattern.RewriteRuleSet = (),
+        rules: Sequence[pattern.RewriteRule] | pattern.RewriteRuleSet,
+        /,
     ) -> None:
-        if pattern_rewrite_rules:
-            if not isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet):
-                # Create a pattern rule-set using provided rules
-                pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules)
-        assert isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet)
-        self.pattern_rewrite_rules: pattern.RewriteRuleSet = pattern_rewrite_rules
+        super().__init__()
+        if isinstance(rules, Sequence):
+            if not rules:
+                raise ValueError("rules must not be empty")
+            # Create a pattern rule-set using provided rules
+            rules = pattern.RewriteRuleSet(rules)
+        assert isinstance(rules, pattern.RewriteRuleSet)
+        self.rules: pattern.RewriteRuleSet = rules
 
     def call(self, model: ir.Model) -> ir.passes.PassResult:
-        count = self.pattern_rewrite_rules.apply_to_model(model)
+        count = self.rules.apply_to_model(model)
         if count:
             print(f"Applied {count} of general pattern rewrite rules.")
         return ir.passes.PassResult(model, bool(count))
 
 
 def rewrite(
-    model: ModelProtoOrIr,
-    pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], pattern.RewriteRuleSet] = (),
-) -> ModelProtoOrIr:
+    model: _ModelProtoOrIr,
+    pattern_rewrite_rules: Union[Sequence[pattern.RewriteRule], pattern.RewriteRuleSet]
+    | None = None,
+) -> _ModelProtoOrIr:
+    """Rewrite the model using the provided pattern rewrite rules.
+
+    Unused nodes, functions, and opsets will be removed after the rewrite.
+
+    Args:
+        model: The model to be rewritten. Can be an ONNX ModelProto or an ir.Model.
+        pattern_rewrite_rules: A sequence of pattern rewrite rules or a RewriteRuleSet.
+            If not provided, default rules will be applied. If empty, no rules will be applied
+            and the original model will be returned.
+
+    Returns:
+        The rewritten model as the same type as the input model.
+    """
+    if pattern_rewrite_rules is None:
+        pattern_rewrite_rules = _DEFAULT_REWRITE_RULES
+    elif not pattern_rewrite_rules:
+        return model
+
     if isinstance(model, onnx.ModelProto):
         model_ir = ir.serde.deserialize_model(model)
         proto = True
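
An illustrative sketch of the updated rewrite() semantics ("model.onnx" is a placeholder path; the llama_p0 rule set is used purely as an example of an explicit rule set):

    import onnx
    from onnxscript import rewriter
    from onnxscript.rewriter import llama_rule_sets

    model_proto = onnx.load("model.onnx")  # placeholder path

    # No rules given: the default rule set (_DEFAULT_REWRITE_RULES) is applied.
    rewritten = rewriter.rewrite(model_proto)

    # An explicit rule set restricts the rewrite to those rules.
    rewritten = rewriter.rewrite(
        model_proto, pattern_rewrite_rules=llama_rule_sets.llama_p0_rule_set()
    )

    # An empty sequence now means "apply no rules": the input is returned unchanged.
    unchanged = rewriter.rewrite(model_proto, pattern_rewrite_rules=[])
    assert unchanged is model_proto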

onnxscript/rewriter/pattern.py

Lines changed: 5 additions & 0 deletions
@@ -1664,13 +1664,18 @@ def _get_new_overload(model: ir.Model, domain: str, name: str) -> str:
 
 class RewriteRuleSet:
     def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None:
+        if not rules:
+            raise ValueError("rules must contain at least one rule")
         if commute:
             rules = list(itertools.chain.from_iterable([rule.commute() for rule in rules]))
         self.rules = rules
         # We call remove_unused_nodes at end of rewriting if there is any rule that does
         # NOT remove nodes (immediately when it is applied)
         self.remove_unused_nodes = any(not rule.remove_nodes for rule in rules)
 
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.rules})"
+
     def _apply_to_graph_or_function(
         self,
         model: ir.Model,
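
A small sketch of the stricter constructor and the new __repr__ (the cast_constant_of_shape rules are used only as an example of an existing rule list):

    from onnxscript.rewriter import cast_constant_of_shape, pattern

    rule_set = pattern.RewriteRuleSet(cast_constant_of_shape.rules.rules)
    print(repr(rule_set))  # the new __repr__ shows the contained rules

    try:
        pattern.RewriteRuleSet([])  # empty rule lists are now rejected up front
    except ValueError as err:
        print(err)  # "rules must contain at least one rule"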

onnxscript/tools/benchmark/__init__.py

Lines changed: 0 additions & 23 deletions
This file was deleted.
