
Add more fusion test-cases (part 1) #2896

Merged
gramalingam merged 16 commits into main from rama/rewriter-tests on Apr 22, 2026

Conversation

gramalingam (Collaborator) commented Apr 18, 2026

Add more unit test cases for fusion rules. (Currently, these fusions are exercised via real-world models in a separate benchmark suite, but dedicated unit tests are missing for several of them.)

  • Erfgelu fusion
  • MHA-Bias fusion
  • MHA-Scale fusion
  • RmsNormalization fusion
  • MHA fusion
  • Rotary Embedding
  • Skip Norm
  • Layer Norm

gramalingam and others added 3 commits April 17, 2026 23:10
Use @script with msft_op for model construction.
Test positive cases with numerical validation: all bias combos
(Q+K+V, Q-only, K-only, V-only, Q+K) verify original and fused
models produce equivalent outputs via ORT.
Test negative cases: no biases, INT32 dtype rejection, rank-2 shape mismatch.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Models now use symbolic dimension names in input_types/output_types
to better reflect real-world models, while concrete values (_B, _S)
are still used for numpy test data generation.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
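
The two commits above describe the overall shape of these unit tests: build a small model with @script (using a com.microsoft opset handle for fused ops and symbolic input dimensions), then check that the original and fused models agree numerically under ORT. Below is a minimal sketch of that shape, using the simpler Erf-based Gelu pattern as a stand-in for the MHA-Bias case; the function names, concrete dims, and tolerance are illustrative assumptions, not the PR's actual mha_bias_test.py code.

```python
import numpy as np
import onnxruntime
from onnxscript import FLOAT, script, values
from onnxscript import opset18 as op

msft_op = values.Opset("com.microsoft", 1)


@script()
def unfused(x: FLOAT["B", "S", 64]) -> FLOAT["B", "S", 64]:
    # Erf-based Gelu: 0.5 * x * (1 + Erf(x / sqrt(2))) -- the kind of pattern a fusion rule matches.
    return 0.5 * (x * (op.Erf(x * 0.7071067811865476) + 1.0))


@script(default_opset=op)  # default_opset pinned since the body only uses a contrib op
def fused(x: FLOAT["B", "S", 64]) -> FLOAT["B", "S", 64]:
    # The fused form the rewriter is expected to produce.
    return msft_op.Gelu(x)


def run(model_proto, x):
    session = onnxruntime.InferenceSession(
        model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    return session.run(None, {"x": x})[0]


# Symbolic dims "B"/"S" in the model; concrete values (e.g. _B=2, _S=8) for the test data.
x = np.random.rand(2, 8, 64).astype(np.float32)
np.testing.assert_allclose(
    run(unfused.to_model_proto(), x), run(fused.to_model_proto(), x), atol=1e-4
)
```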
codecov bot commented Apr 18, 2026

Codecov Report

❌ Patch coverage is 67.49729% with 300 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.51%. Comparing base (c8f5f6a) to head (ad8a8b0).
⚠️ Report is 5 commits behind head on main.
✅ All tests successful. No failed tests found.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| onnxscript/rewriter/ort_fusions/mha_unit_test.py | 57.53% | 58 Missing and 4 partials ⚠️ |
| ...rewriter/rules/fusion/_layer_norm_extended_test.py | 50.00% | 53 Missing and 1 partial ⚠️ |
| ...r/rules/fusion/_rms_normalization_extended_test.py | 61.53% | 39 Missing and 1 partial ⚠️ |
| ...ewriter/ort_fusions/rms_normalization_unit_test.py | 62.50% | 32 Missing and 1 partial ⚠️ |
| onnxscript/rewriter/ort_fusions/erfgelu_test.py | 66.66% | 31 Missing and 1 partial ⚠️ |
| ...rewriter/ort_fusions/rotary_embedding_unit_test.py | 64.47% | 26 Missing and 1 partial ⚠️ |
| onnxscript/rewriter/ort_fusions/mha_bias_test.py | 80.76% | 19 Missing and 1 partial ⚠️ |
| ...writer/ort_fusions/skip_normalization_unit_test.py | 81.81% | 17 Missing and 1 partial ⚠️ |
| onnxscript/rewriter/ort_fusions/mha_scale_test.py | 86.27% | 11 Missing and 3 partials ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2896      +/-   ##
==========================================
+ Coverage   72.48%   72.51%   +0.03%     
==========================================
  Files         241      250       +9     
  Lines       29915    30955    +1040     
  Branches     2935     2960      +25     
==========================================
+ Hits        21684    22448     +764     
- Misses       7233     7504     +271     
- Partials      998     1003       +5     

☔ View full report in Codecov by Sentry.

5 tests covering:
- Scalar float constant scale → fused into MHA scale attribute
- Integer scale constant → fused
- Existing MHA scale attribute → combined with external scale
- No Mul before MHA → no fusion (negative)
- Dynamic (non-constant) scale input → no fusion (negative)

All positive tests include numerical validation via ORT.
Uses symbolic dims ("B", "S") in input/output types.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Use op.Constant(value=ir.tensor(...)) inside @script() functions to
define scale as a graph constant directly, instead of creating it as
a graph input then converting post-hoc. Simpler and more realistic.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
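
A minimal illustration of the change described in that commit (the function name is hypothetical; this is not the actual test code): the scale is emitted as a Constant node inside the scripted function, instead of being declared as a graph input and converted afterwards.

```python
import numpy as np
from onnxscript import FLOAT, ir, script
from onnxscript import opset18 as op


@script()
def scaled_query(query: FLOAT["B", "S", 64]) -> FLOAT["B", "S", 64]:
    # The scale is a graph constant (a Constant node), not a graph input.
    scale = op.Constant(value=ir.tensor(np.float32(0.125)))
    return query * scale
```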
gramalingam and others added 4 commits April 18, 2026 03:05
Use tuples instead of lists for class-level _3D and _OUT constants.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Use "from onnxscript import values" instead of "import onnxscript"
to avoid mixing import and from-import for the same module.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
4 tests covering:
- Both Mul orderings: scale*normalized and normalized*scale (parameterized)
- Mixed-precision: fp16 input with fp32 compute via Cast
- Integer input dtype rejected (negative)

All positive tests include numerical validation via ORT.
Uses symbolic dims ("B", "S") and @script() model construction.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
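
For reference, the RMS-normalization computation these cases exercise, written in NumPy (illustrative, not the test code): the mul_order variants only swap the operands of the final multiplication, and the mixed-precision case performs the same math in fp32 before casting back to fp16.

```python
import numpy as np


def rms_norm_reference(x: np.ndarray, scale: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # scale * x / sqrt(mean(x^2) + eps), reduced over the last axis.
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return scale * (x / np.sqrt(variance + eps))
```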
5 structural tests covering 4 rule variants:
- Basic MHA with key transposed (BHSd format)
- Basic MHA with key not transposed (BSHd format)
- MHA with past key/value (has_past_present=True, 3 outputs)
- MHA with RotaryEmbedding on Q and K
- Rank-2 query shape rejection (negative)

Tests are structural-only (no ORT run) since the pattern requires
internal SDPA nodes (ai.onnxruntime._fusion) that ORT cannot execute.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
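
Since these MHA cases are structural-only, the assertions come down to counting nodes in the rewritten graph. A small helper of the kind such checks typically use (the helper and the commented assertion are illustrative, not the PR's code):

```python
import onnx


def count_ops(model: onnx.ModelProto, op_type: str, domain: str = "") -> int:
    """Count nodes of a given op type and domain in the top-level graph."""
    return sum(
        1 for node in model.graph.node if node.op_type == op_type and node.domain == domain
    )


# After applying the MHA fusion, a structural test would assert something like:
#   assert count_ops(fused_model, "MultiHeadAttention", "com.microsoft") == 1
```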
gramalingam and others added 3 commits April 18, 2026 03:29
The original negative test had a tuple-vs-list comparison bug:
get_ints() returns a tuple, so perm == [0,2,1,3] was always False,
meaning the corruption never happened. Fixed to compare with tuple.
Confirmed the fusion correctly rejects mismatched Transpose perms.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
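
The pitfall that commit fixes, isolated in plain Python: a tuple never compares equal to a list, so the branch that was supposed to corrupt the Transpose perm never ran.

```python
perm = (0, 2, 1, 3)           # what get_ints() returns: a tuple
print(perm == [0, 2, 1, 3])   # False -- tuple vs list, so the old check never matched
print(perm == (0, 2, 1, 3))   # True  -- the comparison the test now uses
```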
22 new tests across 4 files:

rotary_embedding_unit_test.py (3 tests):
- Full rotary embedding pattern fusion
- Partial rotary embedding (adds rotary_embedding_dim attribute)
- 3D input rejection (negative)

skip_normalization_unit_test.py (8 tests):
- SkipRmsNorm: both Add orderings via OrValue (parameterized)
- SkipRmsNorm: post-add bias and pre-add bias variants
- SkipLayerNorm: no bias and post-add bias
- No skip Add (negative), rank-2 input (negative)

_rms_normalization_extended_test.py (5 tests):
- Both mul_order variants: scale*norm and norm*scale (parameterized)
- Mixed-precision: fp16 input with fp32 compute via Cast
- Double precision
- Integer input rejection (negative)

_layer_norm_extended_test.py (6 tests):
- OrValue: Pow(deviation,2) vs Mul(deviation,deviation)
- OrValue: Div(deviation,std_dev) vs Mul(deviation,Reciprocal)
- Both OrValue alternatives combined
- Div + bias fusion
- Double precision
- fp16 input rejection (negative)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
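
The OrValue alternatives listed for _layer_norm_extended_test.py are numerically the same computation; in NumPy terms (illustrative only):

```python
import numpy as np

x = np.random.rand(2, 8, 64).astype(np.float32)
deviation = x - x.mean(axis=-1, keepdims=True)

# Pow(deviation, 2) vs Mul(deviation, deviation)
var_pow = np.mean(deviation**2, axis=-1, keepdims=True)
var_mul = np.mean(deviation * deviation, axis=-1, keepdims=True)

# Div(deviation, std_dev) vs Mul(deviation, Reciprocal(std_dev))
std_dev = np.sqrt(var_pow + 1e-5)
y_div = deviation / std_dev
y_recip = deviation * (1.0 / std_dev)

np.testing.assert_allclose(var_pow, var_mul, rtol=1e-5)
np.testing.assert_allclose(y_div, y_recip, rtol=1e-5)
```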
gramalingam and others added 2 commits April 21, 2026 18:06
Use ONNX reference implementation (ORT lacks RMSNormalization kernel)
to verify original and fused models produce identical results for
float32 and fp16 tests. Double precision remains structural-only
since the reference impl doesn't support stash_type=DOUBLE.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
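
A sketch of the validation approach that commit describes (the helper name is hypothetical): run both ModelProtos through ONNX's reference evaluator, since ORT has no RMSNormalization kernel, and compare the outputs.

```python
import numpy as np
from onnx.reference import ReferenceEvaluator


def assert_outputs_match(original_model, fused_model, feeds, atol=1e-4):
    # The reference evaluator executes a ModelProto directly in Python.
    original_outputs = ReferenceEvaluator(original_model).run(None, feeds)
    fused_outputs = ReferenceEvaluator(fused_model).run(None, feeds)
    for expected, actual in zip(original_outputs, fused_outputs):
        np.testing.assert_allclose(expected, actual, atol=atol)
```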
All 6 positive tests now verify original vs fused model outputs
match using ORT inference. Uses concrete dims for test data while
keeping the structural assertions.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Remove unused local variable 'input_data' in _layer_norm_extended_test.py
- Remove unused global '_EPS_F' in skip_normalization_unit_test.py

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
justinchuby (Collaborator) commented:

Some tests failing

Serializing fused models with RMSNormalization requires onnx opset >= 23.
On older onnx versions, tests now fall back to structural checks only
(fusion count + op type assertions still run).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
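
One way the fallback described above can be expressed (the constant name and test class are hypothetical): gate the numerical comparison on the installed onnx release, while the structural assertions always run.

```python
import unittest

import onnx

# RMSNormalization requires opset >= 23; older onnx releases cannot serialize
# (or reference-evaluate) fused models that use it.
SUPPORTS_RMS_NORMALIZATION = onnx.defs.onnx_opset_version() >= 23


class RmsNormalizationFusionTest(unittest.TestCase):
    def test_basic_fusion(self):
        # ... build the model and apply the fusion ...
        # Structural checks (fusion count, op types) always run here.
        if SUPPORTS_RMS_NORMALIZATION:
            # Numerical comparison of original vs fused outputs runs only
            # when the installed onnx supports opset 23.
            pass
```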
gramalingam (Collaborator, Author) commented:

> Some tests failing

Addressed. This is due to the CI jobs that use onnx==1.17. I wonder why we have so many of these (compared to jobs on later versions). According to Copilot, this is what we have:

CI onnx version matrix

| Session | onnx version | Notes |
|---|---|---|
| test (py310, py311, py312 × 3 OS) | onnx==1.17 (opset 22) | Numerical validation for RMSNormalization skipped |
| test_ort_nightly (py311 × 3 OS) | onnx==1.17 (opset 22) | Numerical validation for RMSNormalization skipped |
| test_onnx_ir_git (py311 × 3 OS) | onnx==1.17 (opset 22) | Numerical validation for RMSNormalization skipped |
| test_onnx_weekly (py311 × 3 OS) | onnx-weekly ~1.22 (opset 25) | Full numerical validation runs |
| test_torch_nightly (py311 × 3 OS) | onnx-weekly ~1.22 (opset 25) | Full numerical validation runs |

gramalingam merged commit 1911741 into main on Apr 22, 2026 (30 of 33 checks passed).
gramalingam deleted the rama/rewriter-tests branch on April 22, 2026 at 01:38.
