Skip to content

Commit 1911741

Browse files
gramalingam and Copilot authored
Add more fusion test-cases (part 1) (#2896)
Add more unit test cases for fusion rules. (Currently, these are tested via real-world models in the benchmark-suite elsewhere, but unit tests are missing for various fusions.) * Erfgelu fusion * MHA-Bias fusion * MHA-Scale fusion * RmsNormalization fusion * MHA fusion * Rotary Embedding * Skip Norm * Layer Norm --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent eae7b8d commit 1911741

File tree

9 files changed

+1828
-0
lines changed

9 files changed

+1828
-0
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import math
5+
import unittest
6+
7+
import numpy as np
8+
import onnx_ir as ir
9+
10+
import onnxscript.rewriter.ort_fusions._test_utils as test_utils
11+
from onnxscript import FLOAT, script
12+
from onnxscript import opset18 as op
13+
from onnxscript.optimizer import optimize, remove_unused_nodes
14+
from onnxscript.rewriter.ort_fusions.erfgelu import fuse_erfgelu
15+
16+
# Precomputed sqrt(2): the erf-based GELU pattern is 0.5 * x * (1 + erf(x / sqrt(2))).
_SQRT_TWO: float = math.sqrt(2.0)
19+
class ErfGeluFusionTest(unittest.TestCase):
    """Tests for erf-based GELU fusion patterns in erfgelu.py.

    Pattern 1: 0.5 * (x * (erf(x / sqrt(2)) + 1))
    Pattern 2: x * (0.5 * (erf(x / sqrt(2)) + 1))
    """

    def _check_fusion(self, model, feeds):
        """Assert that fusion collapses ``model`` into a single com.microsoft
        Gelu node while preserving numerical outputs.

        Args:
            model: IR model expected to contain a fusable erf-gelu subgraph.
            feeds: Mapping of input names to numpy arrays for onnxruntime.
                (Renamed from ``input`` to avoid shadowing the builtin.)
        """
        # Capture reference outputs before the fusion mutates the model in place.
        original_output = test_utils.ort_run("Original", model, feeds)
        fuse_erfgelu(model)
        remove_unused_nodes(model)
        # The entire erf-gelu subgraph must collapse into exactly one fused node.
        self.assertEqual(len(model.graph), 1)
        self.assertEqual(model.graph.node(0).op_type, "Gelu")
        self.assertEqual(model.graph.node(0).domain, "com.microsoft")
        optimized_output = test_utils.ort_run("Optimized", model, feeds)
        test_utils.assert_allclose(original_output, optimized_output)

    def _check_no_fusion(self, model):
        """Assert that fusion leaves ``model`` unchanged (no Gelu introduced)."""
        node_count_before = len(model.graph)
        fuse_erfgelu(model)
        remove_unused_nodes(model)
        self.assertEqual(len(model.graph), node_count_before)
        self.assertTrue(
            all(node.op_type != "Gelu" for node in model.graph),
            "Gelu node should not be present after failed fusion",
        )

    def _build_model(self, script_fn, shape):
        """Build an optimized IR model with one float input/output of ``shape``."""
        model_proto = script_fn.to_model_proto(
            input_types=[FLOAT[shape]], output_types=[FLOAT[shape]]
        )
        model = ir.serde.deserialize_model(model_proto)
        # Pre-optimize (constant folding etc.) so the graph is in the shape
        # the fusion rules expect to match against.
        optimize(model)
        return model

    def test_pattern1_half_times_x_times_erf_plus_one(self):
        """Pattern 1: 0.5 * (x * (erf(x / sqrt(2)) + 1))"""

        @script()
        def erf_gelu_p1(x):
            t1 = op.Div(x, _SQRT_TWO)
            t2 = op.Erf(t1)
            t3 = op.Add(t2, 1.0)
            t4 = op.Mul(x, t3)
            return op.Mul(0.5, t4)

        model = self._build_model(erf_gelu_p1, 10)
        feeds = {"x": np.random.randn(10).astype(np.float32)}
        self._check_fusion(model, feeds)

    def test_pattern2_x_times_half_times_erf_plus_one(self):
        """Pattern 2: x * (0.5 * (erf(x / sqrt(2)) + 1))"""

        @script()
        def erf_gelu_p2(x):
            t1 = op.Div(x, _SQRT_TWO)
            t2 = op.Erf(t1)
            t3 = op.Add(t2, 1.0)
            t4 = op.Mul(0.5, t3)
            return op.Mul(x, t4)

        model = self._build_model(erf_gelu_p2, 10)
        feeds = {"x": np.random.randn(10).astype(np.float32)}
        self._check_fusion(model, feeds)

    def test_multidimensional_input(self):
        """Verify fusion works with 3D inputs (batch, seq, hidden)."""

        @script()
        def erf_gelu_3d(x):
            t1 = op.Div(x, _SQRT_TWO)
            t2 = op.Erf(t1)
            t3 = op.Add(t2, 1.0)
            t4 = op.Mul(x, t3)
            return op.Mul(0.5, t4)

        model = self._build_model(erf_gelu_3d, (2, 4, 8))
        feeds = {"x": np.random.randn(2, 4, 8).astype(np.float32)}
        self._check_fusion(model, feeds)

    def test_no_fusion_without_erf(self):
        """Replacing Erf with Tanh should not match the erf-gelu pattern."""

        @script()
        def not_erf_gelu(x):
            t1 = op.Div(x, _SQRT_TWO)
            t2 = op.Tanh(t1)
            t3 = op.Add(t2, 1.0)
            t4 = op.Mul(x, t3)
            return op.Mul(0.5, t4)

        model = self._build_model(not_erf_gelu, 10)
        self._check_no_fusion(model)

    def test_no_fusion_wrong_divisor(self):
        """Using a divisor other than sqrt(2) should not match."""

        @script()
        def wrong_divisor(x):
            t1 = op.Div(x, 2.0)
            t2 = op.Erf(t1)
            t3 = op.Add(t2, 1.0)
            t4 = op.Mul(x, t3)
            return op.Mul(0.5, t4)

        model = self._build_model(wrong_divisor, 10)
        self._check_no_fusion(model)

    def test_no_fusion_wrong_scale(self):
        """Using 0.3 instead of 0.5 should not match."""

        @script()
        def wrong_scale(x):
            t1 = op.Div(x, _SQRT_TWO)
            t2 = op.Erf(t1)
            t3 = op.Add(t2, 1.0)
            t4 = op.Mul(x, t3)
            return op.Mul(0.3, t4)

        model = self._build_model(wrong_scale, 10)
        self._check_no_fusion(model)
# Allow running this test file directly (pytest discovers it as well).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)