Commit f5327f8

Add BiasGelu, Erfgelu and SkipLayerNormalization fusions (#2222)
This pull request introduces new fusion patterns and enhancements to the ONNXScript rewriter module, focusing on optimization and test-coverage improvements. The key changes include adding support for `BiasGelu` and an additional `ErfGelu` pattern, extending `SkipLayerNormalization` to handle bias addition, and updating test utilities for better accuracy validation.

### New fusion patterns:

* **BiasGelu fusion**: Added a new fusion pattern for `BiasGelu` operations, including its implementation in `onnxscript/rewriter/ort_fusions/bias_gelu.py` and integration into the `fuse_xformers` pipeline. A corresponding unit test was added to validate the functionality.
* **ErfGelu enhancements**: Introduced a second pattern for `ErfGelu` fusion and refactored the corresponding implementation. The file was renamed from `onnxscript/rewriter/erfgelu.py` to `onnxscript/rewriter/ort_fusions/erfgelu.py` for consistency.

### Enhancements to existing fusions:

* **SkipLayerNormalization with bias**: Extended the `SkipLayerNormalization` fusion to support an additional bias term. This includes new patterns and rewrite rules in `onnxscript/rewriter/ort_fusions/skip_normalization.py`.

### Test utility updates:

* **Tolerance adjustment**: Increased the relative and absolute tolerances in `assert_allclose` from `1e-4` to `1e-3` for better handling of numerical discrepancies in tests.
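For orientation, the `com.microsoft` contrib ops these fusions target compute the following (summarized from the onnxruntime contrib-op definitions, not stated in this PR):

$$
\mathrm{Gelu}(x) = \frac{x}{2}\Bigl(1 + \operatorname{erf}\bigl(x/\sqrt{2}\bigr)\Bigr),
\qquad
\mathrm{BiasGelu}(x, b) = \mathrm{Gelu}(x + b)
$$

$$
\mathrm{SkipLayerNorm}(x, s, \gamma, \beta, b) = \mathrm{LayerNorm}\bigl(x + s + b;\ \gamma, \beta, \epsilon\bigr)
$$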
Parent: 3af94a7 · Commit: f5327f8

6 files changed: 136 additions & 10 deletions


onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -13,7 +13,9 @@
     softmax,
 )
 from onnxscript.rewriter.ort_fusions.attention import fuse_attention
+from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu
 from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
+from onnxscript.rewriter.ort_fusions.erfgelu import fuse_erfgelu
 from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
 from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
@@ -65,6 +67,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
     fusion_count = dict()
 
     model = _pre_optimize(model)
+    fusion_count["erf_gelu"] = fuse_erfgelu(model)
     fusion_count["rms_normalization"] = fuse_rms_normalization(model)
     fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model)
     fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model)
@@ -84,6 +87,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
     fusion_count["attention"] = fuse_attention(model)
     fusion_count["gqa"] = 0
     fusion_count["gelu"] = fuse_gelu(model)
+    fusion_count["bias_gelu"] = fuse_bias_gelu(model)
     # Finally: inline any intermediate fusion functions introduced that were not
     # consumed by other fusions, and eliminate any remaining unused nodes.
     optimize(model)
```
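A minimal sketch of driving the pipeline end to end and reading the per-fusion counts. It mirrors the unit test added below; note that importing `fuse_xformers` from the internal `_core` module is an assumption here, and the public entry point may differ:

```python
import onnxscript
import onnxscript.ir as ir
from onnxscript import FLOAT, script
from onnxscript import opset18 as op
from onnxscript.rewriter.ort_fusions._core import fuse_xformers  # internal path, may change

msft_op = onnxscript.values.Opset("com.microsoft", 1)

@script()
def toy_model(x, y):
    # An Add feeding a com.microsoft Gelu: the shape the bias_gelu rule matches.
    gelu_add = op.Add(x, y)
    return msft_op.Gelu(gelu_add)

proto = toy_model.to_model_proto(
    input_types=[FLOAT[10], FLOAT[10]], output_types=[FLOAT[10]], ir_version=10
)
model = ir.serde.deserialize_model(proto)
model, counts = fuse_xformers(model)
print(counts["bias_gelu"])  # should print 1 for this graph
```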

onnxscript/rewriter/ort_fusions/_test_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -33,7 +33,7 @@ def ort_run(model_name: str, model, inputs):
     return session.run(None, inputs)
 
 
-def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4):
+def assert_allclose(outputs, expected_outputs, rtol=1e-3, atol=1e-3):
     for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)):
         try:
             np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
```
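Assuming the helper delegates value comparison to `np.testing.assert_allclose` (only the shape check is visible in this hunk), the element-wise bound being loosened is:

$$
\lvert \text{actual} - \text{desired} \rvert \le \text{atol} + \text{rtol} \cdot \lvert \text{desired} \rvert
$$

so moving both tolerances from `1e-4` to `1e-3` relaxes the check by an order of magnitude.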
onnxscript/rewriter/ort_fusions/bias_gelu.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -0,0 +1,22 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from onnxscript.rewriter import _fusion_utils, pattern
+
+
+class BiasGeluFusion(pattern.RewriteRuleClassBase):
+    def pattern(self, op, x, y):
+        gelu_add = op.Add(x, y)
+        return op.Gelu(gelu_add, _domain="com.microsoft")
+
+    def rewrite(self, op, x, y):
+        return op.BiasGelu(x, y, _domain="com.microsoft")
+
+
+_rule = BiasGeluFusion.rule()
+
+bias_gelu_rules = pattern.RewriteRuleSet([_rule])
+
+
+fuse_bias_gelu = _fusion_utils.apply_fusion_rules(bias_gelu_rules)
```
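The class-based style above and the function-based style used by `erfgelu.py` later in this diff are two spellings of the same rule. A hedged sketch of the function-based equivalent, for comparison only (not part of the commit):

```python
from onnxscript.rewriter import pattern

def _bias_gelu_pattern(op, x, y):
    # Match an Add whose result feeds a com.microsoft Gelu.
    return op.Gelu(op.Add(x, y), _domain="com.microsoft")

def _bias_gelu_rewrite(op, x, y):
    # Replace the matched pair with the fused contrib op.
    return op.BiasGelu(x, y, _domain="com.microsoft")

_bias_gelu_rule = pattern.RewriteRule(_bias_gelu_pattern, _bias_gelu_rewrite)
```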
onnxscript/rewriter/ort_fusions/bias_gelu_test.py

Lines changed: 52 additions & 0 deletions
```diff
@@ -0,0 +1,52 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import unittest
+
+import numpy as np
+
+import onnxscript
+import onnxscript.ir as ir
+import onnxscript.rewriter.ort_fusions._test_utils as test_utils
+from onnxscript import FLOAT, script
+from onnxscript import opset18 as op
+from onnxscript.optimizer import optimize, remove_unused_nodes
+from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu
+
+msft_op = onnxscript.values.Opset("com.microsoft", 1)
+
+
+class BiasGeluFusionTest(unittest.TestCase):
+    def test_bias_gelu_fusion(self):
+        @script()
+        def bias_gelu_model(x, y):
+            gelu_add = op.Add(x, y)
+            gelu = msft_op.Gelu(gelu_add)
+            return gelu
+
+        model_proto = bias_gelu_model.to_model_proto(
+            input_types=[FLOAT[10], FLOAT[10]],
+            output_types=[FLOAT[10]],
+            ir_version=10,
+        )
+        model = ir.serde.deserialize_model(model_proto)
+        optimize(model)
+
+        input = {
+            "x": np.random.randn(10).astype(np.float32),
+            "y": np.random.randn(10).astype(np.float32),
+        }
+        original_output = test_utils.ort_run("Original", model, input)
+
+        fuse_bias_gelu(model)
+        remove_unused_nodes(model)
+
+        self.assertEqual(len(model.graph), 1)
+        self.assertEqual(model.graph.node(0).op_type, "BiasGelu")
+
+        optimized_output = test_utils.ort_run("Optimized", model, input)
+        test_utils.assert_allclose(original_output, optimized_output)
+
+
+if __name__ == "__main__":
+    unittest.main()
```

onnxscript/rewriter/erfgelu.py renamed to onnxscript/rewriter/ort_fusions/erfgelu.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -2,11 +2,11 @@
 # Licensed under the MIT License.
 import math
 
-from onnxscript.rewriter import pattern
+from onnxscript.rewriter import _fusion_utils, pattern
 
 
 # Pattern to match against
-def erf_gelu_pattern(op, x):
+def erf_gelu_pattern_1(op, x):
     # erf_gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
     # half = pattern.Constant(0.5)
     # sqrt2 = pattern.Constant(1.4142)
@@ -19,9 +19,18 @@ def erf_gelu_pattern(op, x):
     return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
 
 
+def erf_gelu_pattern_2(op, x):
+    return x * (0.5 * (op.Erf(x / math.sqrt(2)) + 1.0))
+
+
 # Replacement
 def gelu(op, x):
     return op.Gelu(x, _domain="com.microsoft")
 
 
-rule = pattern.RewriteRule(erf_gelu_pattern, gelu)
+rule1 = pattern.RewriteRule(erf_gelu_pattern_1, gelu)
+rule2 = pattern.RewriteRule(erf_gelu_pattern_2, gelu)
+
+rules = pattern.RewriteRuleSet([rule1, rule2])
+
+fuse_erfgelu = _fusion_utils.apply_fusion_rules(rules)
```
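Both patterns denote the same exact GELU; a second rule is needed only because the rewriter matches graph structure rather than doing algebraic rewriting, so the two parenthesizations are distinct subgraphs:

$$
0.5 \cdot \bigl(x \cdot (\operatorname{erf}(x/\sqrt{2}) + 1)\bigr)
= x \cdot \bigl(0.5 \cdot (\operatorname{erf}(x/\sqrt{2}) + 1)\bigr)
= \mathrm{Gelu}(x)
$$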

onnxscript/rewriter/ort_fusions/skip_normalization.py

Lines changed: 45 additions & 6 deletions
```diff
@@ -47,27 +47,66 @@ def _skip_layer_norm_pattern(op, input, skip, gamma, beta, epsilon, stash_type):
         epsilon=epsilon,
         stash_type=stash_type,
     )
-    return normalized
+    return normalized, skip_sum
 
 
 def _skip_layer_normalization(op, input, skip, gamma, beta, epsilon, stash_type):
     if stash_type.value != 1:  # FLOAT type
         return None
-    normalized, _mean, _inv_std_var = op.SkipLayerNormalization(
+    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
         input,
         skip,
         gamma,
         beta,
         epsilon=epsilon,
-        _outputs=3,
+        _outputs=4,
+        _domain="com.microsoft",
+    )
+    return normalized, skip_sum
+
+
+# Fusion rule for Add + SkipLayerNormalization
+def _skip_layer_norm_add_bias_pattern(op, input, skip, gamma, beta, bias, epsilon, stash_type):
+    bias_sum = op.Add(input, bias)
+    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
+        bias_sum,
+        skip,
+        gamma,
+        beta,
+        epsilon=epsilon,
+        _outputs=4,
         _domain="com.microsoft",
     )
-    return normalized
+    return normalized, skip_sum
+
 
+def _skip_layer_normalization_add_bias(
+    op, input, skip, gamma, beta, bias, epsilon, stash_type
+):
+    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
+        input,
+        skip,
+        gamma,
+        beta,
+        bias,
+        epsilon=epsilon,
+        _outputs=4,
+        _domain="com.microsoft",
+    )
+    return normalized, skip_sum
+
+
+_skip_layer_rule = pattern.RewriteRule(
+    _skip_layer_norm_pattern, _skip_layer_normalization, name="SkipLayerNorm"
+)
+_skip_layer_add_bias_rule = pattern.RewriteRule(
+    _skip_layer_norm_add_bias_pattern,
+    _skip_layer_normalization_add_bias,
+    name="SkipLayerNormAddBias",
+)
 
-_skip_layer_rule = pattern.RewriteRule(_skip_layer_norm_pattern, _skip_layer_normalization)
 
-skip_layer_normalization_rules = [_skip_layer_rule]
+skip_layer_normalization_rules = [_skip_layer_rule, _skip_layer_add_bias_rule]
 skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules)
```
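The net effect of the new rule pair: an `Add(input, bias)` feeding a `SkipLayerNormalization` is folded into the node's optional bias input. As I understand the `com.microsoft` contrib-op semantics, the fused node computes

$$
\mathrm{output} = \mathrm{LayerNorm}\bigl(\mathrm{input} + \mathrm{skip} + \mathrm{bias};\ \gamma, \beta, \epsilon\bigr),
\qquad
\mathrm{skip\_sum} = \mathrm{input} + \mathrm{skip} + \mathrm{bias}
$$

and returning `skip_sum` from the patterns lets the rewriter redirect any consumers of the original sum to the fused node's fourth output.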
