Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
import onnxscript.ir.passes.common as common_passes
from onnxscript import ir
from onnxscript.rewriter import (
basic_rules,
broadcast_to_matmul,
cast_constant_of_shape,
collapse_slices,
gemm_to_matmul_add,
llama_rule_sets,
no_op,
pattern,
)
Expand All @@ -31,7 +31,7 @@
gemm_to_matmul_add.rule, # type: ignore[has-type]
*cast_constant_of_shape.rules.rules,
*collapse_slices.rules.rules,
*llama_rule_sets.llama_p0_rule_set().rules,
*basic_rules.basic_optimization_rules().rules,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Basic rewrite rules for general optimization patterns.

This module contains fundamental optimization rules that are generally applicable
to most ONNX models, including cast elimination, transpose simplification,
shape operation fusion, and other common patterns.
"""
from __future__ import annotations

from typing import ClassVar, Sequence
Expand Down Expand Up @@ -271,6 +277,7 @@
return check_result


# Create rule instances
cast_cast_rule = CastCast.rule()
cast_identity_rule = CastIdentity.rule()
expand_identity_rule = ExpandIdentity.rule()
Expand All @@ -282,24 +289,31 @@
squeeze_reshape_1d_rule = SqueezeReshape.rule()


def llama_p0_rule_set() -> orp.RewriteRuleSet:
"""Returns a set of rules which should be applied
before any other one as they usually remove unnecessary computation
such as the multiplication by 1 or two consecutive transpose.
def basic_optimization_rules() -> orp.RewriteRuleSet:
"""Returns a set of basic optimization rules.

Comment thread Fixed
Comment thread Fixed
These rules perform fundamental optimizations such as:
- Eliminating redundant cast operations
- Simplifying consecutive operations of the same type
- Removing identity operations
- Optimizing shape manipulation operations

Comment thread Fixed
Comment thread Fixed
These rules are generally safe to apply as a first optimization pass
before other more specialized optimizations.

Returns:
RewriteRuleSet
RewriteRuleSet: A collection of basic optimization rules
"""
return orp.RewriteRuleSet(
[
cast_cast_rule,
cast_identity_rule,
expand_identity_rule,
reshape_reshape_rule,
slice_split_rule, # Affect collapse slices rules?
slice_split_rule,
transpose_identity_rule,
transpose_transpose_rule,
unsqueeze_unsqueeze_rule,
squeeze_reshape_1d_rule,
]
)
)
Comment thread Fixed
Loading
Loading