Skip to content

Commit 5461135

Browse files
gramalingam and Copilot committed
Add unit tests for rotary embedding, skip norm, rms norm, layer norm
22 new tests across 4 files: rotary_embedding_unit_test.py (3 tests): - Full rotary embedding pattern fusion - Partial rotary embedding (adds rotary_embedding_dim attribute) - 3D input rejection (negative) skip_normalization_unit_test.py (8 tests): - SkipRmsNorm: both Add orderings via OrValue (parameterized) - SkipRmsNorm: post-add bias and pre-add bias variants - SkipLayerNorm: no bias and post-add bias - No skip Add (negative), rank-2 input (negative) _rms_normalization_extended_test.py (5 tests): - Both mul_order variants: scale*norm and norm*scale (parameterized) - Mixed-precision: fp16 input with fp32 compute via Cast - Double precision - Integer input rejection (negative) _layer_norm_extended_test.py (6 tests): - OrValue: Pow(deviation,2) vs Mul(deviation,deviation) - OrValue: Div(deviation,std_dev) vs Mul(deviation,Reciprocal) - Both OrValue alternatives combined - Div + bias fusion - Double precision - fp16 input rejection (negative) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent c6beadb commit 5461135

File tree

4 files changed

+746
-0
lines changed

4 files changed

+746
-0
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Unit tests for RotaryEmbeddingFusion and PartialRotaryEmbeddingFusion rules.
5+
6+
RotaryEmbeddingFusion matches: x * cos + rotate_half(x) * sin
7+
and rewrites to RotaryEmbedding(x, cos, sin, ...).
8+
9+
PartialRotaryEmbeddingFusion matches: Concat(RotaryEmbedding(Slice(x)), Slice(x))
10+
and adds rotary_embedding_dim attribute to the RotaryEmbedding op.
11+
"""
12+
13+
from __future__ import annotations
14+
15+
import unittest
16+
17+
import numpy as np
18+
import onnx_ir as ir
19+
20+
from onnxscript import FLOAT, script, values
21+
from onnxscript import opset18 as op
22+
from onnxscript.optimizer import optimize
23+
from onnxscript.rewriter.ort_fusions.rotary_embedding import (
24+
fuse_partial_rotary_embedding,
25+
fuse_rotary_embedding,
26+
)
27+
28+
fusion_op = values.Opset("ai.onnxruntime._fusion", 1)
msft_op = values.Opset("com.microsoft", 1)

# Test dimensions: batch, heads, sequence length, head size.
_B, _H, _S, _Dh = 2, 4, 8, 8
_HALF = _Dh // 2

# 1-D int64 tensors used as Slice boundaries by the script functions below.
_ZERO = ir.tensor(np.array([0], dtype=np.int64))
_HALF_TENSOR = ir.tensor(np.array([_HALF], dtype=np.int64))
_HEAD_SIZE_TENSOR = ir.tensor(np.array([_Dh], dtype=np.int64))
# INT64_MAX (9223372036854775807), the conventional open-ended Slice end.
_MAX_INT = ir.tensor(np.array([np.iinfo(np.int64).max], dtype=np.int64))
39+
40+
41+
# --- Full rotary embedding pattern ---
42+
43+
44+
@script()
def _rotary_full(x, cos, sin):
    """Standard non-interleaved rotary pattern: x * cos + rotate_half(x) * sin."""
    lower_start = op.Constant(value=_ZERO)
    lower_end = op.Constant(value=_HALF_TENSOR)
    upper_start = op.Constant(value=_HALF_TENSOR)
    upper_end = op.Constant(value=_HEAD_SIZE_TENSOR)
    # rotate_half: split the last axis in two and emit concat(-upper, lower).
    lower = op.Slice(x, lower_start, lower_end, [3], [1])
    upper = op.Slice(x, upper_start, upper_end, [3], [1])
    negated_upper = op.Neg(upper)
    half_rotated = op.Concat(negated_upper, lower, axis=-1)
    return op.Add(op.Mul(x, cos), op.Mul(half_rotated, sin))
57+
58+
59+
# --- Partial rotary embedding pattern ---
60+
61+
62+
@script()
def _partial_rotary(x, cos, sin, position_ids):
    """Apply RotaryEmbedding to the first half of the last axis; pass the rest through."""
    rot_end = op.Constant(value=_HALF_TENSOR)
    pass_start = op.Constant(value=_HALF_TENSOR)
    open_end = op.Constant(value=_MAX_INT)
    rotary_part = op.Slice(x, [0], rot_end, [3], [1])
    passthrough_part = op.Slice(x, pass_start, open_end, [3], [1])
    rotated = msft_op.RotaryEmbedding(rotary_part, position_ids, cos, sin, interleaved=0)
    return op.Concat(rotated, passthrough_part, axis=-1)
72+
73+
74+
# --- Negative: 3D input instead of 4D ---
75+
76+
77+
@script()
def _rotary_3d_input(x, cos, sin):
    """Same graph as _rotary_full, but intended for rank-3 inputs (should fail the 4D check)."""
    s_lo = op.Constant(value=_ZERO)
    e_lo = op.Constant(value=_HALF_TENSOR)
    s_hi = op.Constant(value=_HALF_TENSOR)
    e_hi = op.Constant(value=_HEAD_SIZE_TENSOR)
    first_half = op.Slice(x, s_lo, e_lo, [3], [1])
    second_half = op.Slice(x, s_hi, e_hi, [3], [1])
    neg_second = op.Neg(second_half)
    half_rotated = op.Concat(neg_second, first_half, axis=-1)
    return op.Add(op.Mul(x, cos), op.Mul(half_rotated, sin))
89+
90+
91+
class RotaryEmbeddingFusionTest(unittest.TestCase):
    """Unit tests for RotaryEmbeddingFusion rule."""

    def _build(self, script_fn, input_types, output_types) -> ir.Model:
        """Convert a script function into an optimized ir.Model."""
        proto = script_fn.to_model_proto(
            input_types=input_types, output_types=output_types
        )
        ir_model = ir.serde.deserialize_model(proto)
        optimize(ir_model)
        return ir_model

    def _count_op(self, model: ir.Model, op_type: str, domain: str = "") -> int:
        """Count graph nodes with the given op type in the given domain."""
        matches = [
            node
            for node in model.graph
            if node.op_type == op_type and node.domain == domain
        ]
        return len(matches)

    # --- Positive tests ---

    def test_full_rotary_fuses(self):
        """Standard full rotary embedding pattern → fuses to RotaryEmbedding op."""
        shape_4d = FLOAT[_B, _H, "S", _Dh]
        model = self._build(
            _rotary_full,
            input_types=[shape_4d, shape_4d, shape_4d],
            output_types=[shape_4d],
        )
        fused = fuse_rotary_embedding(model)
        self.assertEqual(fused, 1)
        self.assertEqual(self._count_op(model, "RotaryEmbedding", "ai.onnxruntime._fusion"), 1)
        # The rotate_half Neg should have been consumed by the fusion.
        self.assertEqual(self._count_op(model, "Neg"), 0)

    def test_partial_rotary_fuses(self):
        """Partial rotary: Slice → RotaryEmbedding on first part, concat with rest."""
        model = self._build(
            _partial_rotary,
            input_types=[
                FLOAT[_B, _H, "S", _Dh],
                FLOAT["S", _HALF],
                FLOAT["S", _HALF],
                FLOAT[_B, "S"],
            ],
            output_types=[FLOAT[_B, _H, "S", _Dh]],
        )
        fused = fuse_partial_rotary_embedding(model)
        self.assertEqual(fused, 1)
        # The RotaryEmbedding node survives, gaining a rotary_embedding_dim attribute.
        rope = [
            node
            for node in model.graph
            if node.op_type == "RotaryEmbedding" and node.domain == "com.microsoft"
        ]
        self.assertEqual(len(rope), 1)
        attrs = rope[0].attributes
        self.assertIn("rotary_embedding_dim", attrs)
        self.assertEqual(attrs["rotary_embedding_dim"].value, _HALF)

    # --- Negative tests ---

    def test_3d_input_no_fusion(self):
        """3D input (missing batch or head dim) → check rejects."""
        shape_3d = FLOAT[_H, "S", _Dh]
        proto = _rotary_3d_input.to_model_proto(
            input_types=[shape_3d, shape_3d, shape_3d],
            output_types=[shape_3d],
        )
        model = ir.serde.deserialize_model(proto)
        # optimize() is deliberately skipped: 3D shapes cause shape inference
        # errors, which is expected here.
        self.assertEqual(fuse_rotary_embedding(model), 0)
164+
165+
166+
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Unit tests for SkipRmsNormFusion and SkipLayerNormFusion rules.
5+
6+
SkipRmsNormFusion: Add(input, skip) → SimplifiedLayerNormalization →
7+
SkipSimplifiedLayerNormalization (com.microsoft).
8+
9+
SkipLayerNormFusion: Add(input, skip) → LayerNormalization →
10+
SkipLayerNormalization (com.microsoft).
11+
12+
Covers: no bias, post-add bias, pre-add bias variants; negative tests.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import unittest
18+
19+
import numpy as np
20+
import onnx_ir as ir
21+
from parameterized import parameterized
22+
23+
from onnxscript import FLOAT, script
24+
from onnxscript import opset18 as op
25+
from onnxscript.optimizer import optimize
26+
from onnxscript.rewriter.ort_fusions.skip_normalization import (
27+
fuse_skip_layer_normalization,
28+
fuse_skip_rms_normalization,
29+
)
30+
31+
# Test dimensions: batch, sequence length, hidden size.
_B, _S, _D = 2, 8, 16
# Epsilon as a one-element float32 tensor.
_EPS_F = ir.tensor(np.array([1e-6], dtype=np.float32))
33+
34+
35+
# ========== Skip RMS Norm patterns ==========
36+
37+
38+
@script()
def _skip_rms_no_bias(input, skip, gamma):
    """input + skip feeding SimplifiedLayerNormalization (no bias)."""
    residual = op.Add(input, skip)
    return op.SimplifiedLayerNormalization(residual, gamma, axis=-1, epsilon=1e-6, stash_type=1)
44+
45+
46+
@script()
def _skip_rms_no_bias_reversed(input, skip, gamma):
    """Skip + input order reversed (OrValue alternative)."""
    residual = op.Add(skip, input)
    return op.SimplifiedLayerNormalization(residual, gamma, axis=-1, epsilon=1e-6, stash_type=1)
53+
54+
55+
@script()
def _skip_rms_post_bias(input, skip, gamma, bias):
    """(input + skip) + bias feeding SimplifiedLayerNormalization."""
    residual = op.Add(input, skip)
    biased_residual = op.Add(residual, bias)
    return op.SimplifiedLayerNormalization(biased_residual, gamma, axis=-1, epsilon=1e-6, stash_type=1)
62+
63+
64+
@script()
def _skip_rms_pre_bias(input, skip, gamma, bias):
    """(input + bias) + skip feeding SimplifiedLayerNormalization (pre-add bias)."""
    biased_input = op.Add(input, bias)
    residual = op.Add(biased_input, skip)
    return op.SimplifiedLayerNormalization(residual, gamma, axis=-1, epsilon=1e-6, stash_type=1)
71+
72+
73+
# ========== Skip Layer Norm patterns ==========
74+
75+
76+
@script()
def _skip_ln_no_bias(input, skip, gamma, beta):
    """input + skip feeding LayerNormalization (no extra bias)."""
    residual = op.Add(input, skip)
    return op.LayerNormalization(residual, gamma, beta, axis=-1, epsilon=1e-6, stash_type=1)
80+
81+
82+
@script()
def _skip_ln_post_bias(input, skip, gamma, beta, bias):
    """(input + skip) + bias feeding LayerNormalization."""
    residual = op.Add(input, skip)
    biased_residual = op.Add(residual, bias)
    return op.LayerNormalization(biased_residual, gamma, beta, axis=-1, epsilon=1e-6, stash_type=1)
89+
90+
91+
# ========== Negative patterns ==========
92+
93+
94+
@script()
def _skip_rms_no_add(input, gamma):
    """SimplifiedLayerNormalization with no preceding skip Add — must not fuse."""
    return op.SimplifiedLayerNormalization(input, gamma, axis=-1, epsilon=1e-6, stash_type=1)
98+
99+
100+
class SkipNormalizationTest(unittest.TestCase):
    """Unit tests for SkipRmsNormFusion and SkipLayerNormFusion."""

    # Shorthand shape types shared by all test cases.
    _3D = FLOAT["B", "S", _D]
    _1D = FLOAT[_D]

    def _build(self, script_fn, input_types, output_types) -> ir.Model:
        """Convert a script function into an optimized ir.Model."""
        proto = script_fn.to_model_proto(
            input_types=input_types, output_types=output_types
        )
        ir_model = ir.serde.deserialize_model(proto)
        optimize(ir_model)
        return ir_model

    def _count_op(self, model: ir.Model, op_type: str, domain: str = "") -> int:
        """Count graph nodes with the given op type in the given domain."""
        hits = [
            node
            for node in model.graph
            if node.op_type == op_type and node.domain == domain
        ]
        return len(hits)

    # ---- Skip RMS Norm positive tests ----

    @parameterized.expand([
        ("input_plus_skip", _skip_rms_no_bias),
        ("skip_plus_input", _skip_rms_no_bias_reversed),
    ])
    def test_skip_rms_no_bias(self, _name, script_fn):
        """Skip + Input (both orderings) → SkipSimplifiedLayerNormalization."""
        model = self._build(
            script_fn,
            input_types=[self._3D, self._3D, self._1D],
            output_types=[self._3D],
        )
        fused = fuse_skip_rms_normalization(model)
        self.assertGreater(fused, 0)
        self.assertEqual(
            self._count_op(model, "SkipSimplifiedLayerNormalization", "com.microsoft"), 1
        )
        # The stand-alone norm must have been consumed by the fusion.
        self.assertEqual(self._count_op(model, "SimplifiedLayerNormalization"), 0)

    def test_skip_rms_post_bias(self):
        """(Input + Skip) + Bias → fuses with bias."""
        model = self._build(
            _skip_rms_post_bias,
            input_types=[self._3D, self._3D, self._1D, self._1D],
            output_types=[self._3D],
        )
        fused = fuse_skip_rms_normalization(model)
        self.assertGreater(fused, 0)
        self.assertEqual(
            self._count_op(model, "SkipSimplifiedLayerNormalization", "com.microsoft"), 1
        )

    def test_skip_rms_pre_bias(self):
        """(Input + Bias) + Skip → fuses with pre-add bias."""
        model = self._build(
            _skip_rms_pre_bias,
            input_types=[self._3D, self._3D, self._1D, self._1D],
            output_types=[self._3D],
        )
        fused = fuse_skip_rms_normalization(model)
        self.assertGreater(fused, 0)
        self.assertEqual(
            self._count_op(model, "SkipSimplifiedLayerNormalization", "com.microsoft"), 1
        )

    # ---- Skip Layer Norm positive tests ----

    def test_skip_ln_no_bias(self):
        """Skip + Input → SkipLayerNormalization."""
        model = self._build(
            _skip_ln_no_bias,
            input_types=[self._3D, self._3D, self._1D, self._1D],
            output_types=[self._3D],
        )
        fused = fuse_skip_layer_normalization(model)
        self.assertGreater(fused, 0)
        self.assertEqual(self._count_op(model, "SkipLayerNormalization", "com.microsoft"), 1)
        self.assertEqual(self._count_op(model, "LayerNormalization"), 0)

    def test_skip_ln_post_bias(self):
        """(Input + Skip) + Bias → fuses with bias."""
        model = self._build(
            _skip_ln_post_bias,
            input_types=[self._3D, self._3D, self._1D, self._1D, self._1D],
            output_types=[self._3D],
        )
        fused = fuse_skip_layer_normalization(model)
        self.assertGreater(fused, 0)
        self.assertEqual(self._count_op(model, "SkipLayerNormalization", "com.microsoft"), 1)

    # ---- Negative tests ----

    def test_no_skip_add_no_fusion(self):
        """No Add before norm → rule should not match."""
        model = self._build(
            _skip_rms_no_add,
            input_types=[self._3D, self._1D],
            output_types=[self._3D],
        )
        self.assertEqual(fuse_skip_rms_normalization(model), 0)
        self.assertEqual(
            self._count_op(model, "SkipSimplifiedLayerNormalization", "com.microsoft"), 0
        )

    def test_rank2_input_no_fusion(self):
        """Rank-2 input [S, D] → shape check rejects (expects 3D)."""
        shape_2d = FLOAT["S", _D]
        model = self._build(
            _skip_rms_no_bias,
            input_types=[shape_2d, shape_2d, self._1D],
            output_types=[shape_2d],
        )
        self.assertEqual(fuse_skip_rms_normalization(model), 0)
214+
215+
216+
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)