Add remaining high-priority rewriter extended tests#2899
Add remaining high-priority rewriter extended tests#2899gramalingam wants to merge 4 commits intomainfrom
Conversation
…les) Extended test coverage for: - CastConstantOfShape: dtype variants (int32, float64, bfloat16), same-dtype positive - FuseConvAffine: non-constant weight/bias/scale negatives, padded pre-conv negative - RedundantScatterND: axis=0 dynamic positive, partial indices negative, shape mismatch negative - PartialRotaryEmbedding23: positive with dim attr validation, boundary mismatch negative, already-has-attr negative - GQA (rules/fusion): wrong group count negative, different head config positive - GQA (ort_fusions): no RotaryEmbedding negative - GQA packed QKV: misaligned slice boundaries negative - cos_sin_cache: non-constant inv_freq negative, constant inv_freq positive with ORT validation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Also adds a positive test with numerical validation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…alue_info Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add _make_partial_rotary_script(mismatched) to _rotary_embedding_models.py that uses a traced-if to generate matching or mismatched slice boundaries. This replaces the graph-mutation approach in the test with a clean model built directly from @script(). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
| import onnx_ir as ir | ||
| from packaging import version | ||
|
|
||
| import onnxscript |
| import numpy | ||
| import onnx_ir as ir | ||
|
|
||
| import onnxscript |
| import onnx_ir as ir | ||
| import onnx_ir.passes.common.shape_inference as shape_inference | ||
|
|
||
| import onnxscript |
| import onnx_ir as ir | ||
| import onnx_ir.passes.common.shape_inference as shape_inference | ||
|
|
||
| import onnxscript |
❌ 1 Tests Failed:
View the top 1 failed test(s) by shortest run time
View the full list of 1 ❄️ flaky test(s)
To view more test analytics, go to the Test Analytics Dashboard |
There was a problem hiding this comment.
Pull request overview
This PR expands the ONNXScript rewriter “extended” test suite to cover additional high-priority fusion/rewrite gaps, and introduces a small model-generator helper to build cleaner negative rotary-embedding cases.
Changes:
- Added new extended test modules for multiple rewriter rules (fusion, common, and ort_fusions).
- Added negative/positive scenarios including boundary mismatches and prerequisite-missing cases.
- Extended the rotary-embedding model helpers with a script generator for partial rotary patterns.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxscript/rewriter/rules/fusion/_rotary_embedding_extended_test.py | Adds extended tests for PartialRotaryEmbedding23 (positive + negative cases). |
| onnxscript/rewriter/rules/fusion/_gqa_extended_test.py | Adds extended tests for fusion GQA rule (negative wrong-group + a positive alt config). |
| onnxscript/rewriter/rules/common/_redundant_scatter_nd_extended_test.py | Adds extended tests for redundant ScatterND elimination. |
| onnxscript/rewriter/rules/common/_fuse_conv_affine_extended_test.py | Adds extended tests for Conv/Affine fusion rules (positive + multiple negatives). |
| onnxscript/rewriter/rules/common/_cast_constant_of_shape_extended_test.py | Adds extended dtype-coverage tests for Cast+ConstantOfShape fusion. |
| onnxscript/rewriter/ort_fusions/gqa_packed_qkv_extended_test.py | Adds extended negative test for packed-QKV GQA fusion. |
| onnxscript/rewriter/ort_fusions/gqa_extended_test.py | Adds an extended negative test for ort_fusions GQA (no RotaryEmbedding prerequisite). |
| onnxscript/rewriter/ort_fusions/cos_sin_cache_extended_test.py | Adds extended tests for cos/sin cache fusion (dynamic inv_freq negative + sanity positive). |
| onnxscript/rewriter/models/_rotary_embedding_models.py | Adds _make_partial_rotary_script(mismatched=...) generator to support negative test construction. |
| def model_fn(data: FLOAT[8, 16], updates: FLOAT[8, 16]) -> FLOAT[8, 16]: | ||
| axis = op.Constant(value_int=0) | ||
| shape = op.Shape(data, start=0) | ||
| dim = op.Gather(shape, axis, axis=0) | ||
| full_range = op.Range(0, dim, 1) | ||
| full_range_2d = op.Unsqueeze(full_range, [-1]) | ||
| scattered = op.ScatterND(data, full_range_2d, updates, reduction="none") | ||
| return scattered | ||
|
|
||
| model_proto = model_fn.to_model_proto() | ||
| model = ir.serde.deserialize_model(model_proto) | ||
| onnx_check(model) | ||
| shape_inference(model) | ||
| onnxscript.optimizer.fold_constants(model) | ||
| count = _redundant_scatter_nd.rules.apply_to_model(model) | ||
| self.assertEqual(count, 1) |
There was a problem hiding this comment.
In test_dynamic_indices_axis_0, the input types are fully static (FLOAT[8, 16]) and the test runs fold_constants() before applying the rewrite. With fully-known shapes, constant folding can precompute Shape/Gather/Range/Unsqueeze, removing the dynamic-indices subgraph that ScatterAllDynamic matches (and ScatterAllStatic may not match due to reduction="none"). Consider making the updated axis dimension symbolic (e.g., FLOAT["N", 16]) or skipping constant-folding here so the dynamic-pattern nodes remain and the rule is actually exercised reliably.
| Adds coverage for: different axis (dynamic), partial scatter (negative), | ||
| shape mismatch (negative for static), and reduction="add" (negative for static). |
There was a problem hiding this comment.
The module docstring claims coverage for reduction="add" and an additional negative case, but this file currently only includes the axis=0 positive test plus two static negative tests (partial indices, shape mismatch). Either add the missing test(s) or update the docstring to reflect the actual coverage.
| Adds coverage for: different axis (dynamic), partial scatter (negative), | |
| shape mismatch (negative for static), and reduction="add" (negative for static). | |
| Adds coverage for: dynamic full-range scatter on axis=0 (positive), | |
| partial scatter (negative for static), and shape mismatch (negative for static). |
|
|
||
| """Extended tests for GQA ort_fusions fusion. | ||
|
|
||
| Adds coverage for: wrong expand shape (negative — group count mismatch). |
There was a problem hiding this comment.
The file docstring says this adds coverage for a “wrong expand shape (negative — group count mismatch)”, but the only test in this file is test_no_rotary_embedding_no_gqa_fusion (no RotaryEmbedding prerequisite). Please update the docstring (or rename the test/add the intended group-mismatch case) so the stated coverage matches what the test actually validates.
| Adds coverage for: wrong expand shape (negative — group count mismatch). | |
| Adds coverage for: missing RotaryEmbedding (negative — no GQA fusion). |
Summary
Add 22 extended tests across 8 test files covering the remaining high-priority gaps identified in the rewriter test gap analysis.
New test files
_cast_constant_of_shape_extended_test.py_cast_constant_of_shape_fuse_conv_affine_extended_test.py_fuse_conv_affine_redundant_scatter_nd_extended_test.py_redundant_scatter_nd_rotary_embedding_extended_test.py(rules/fusion)PartialRotaryEmbedding23Fusion_gqa_extended_test.py(rules/fusion)GroupQueryAttentionFusiongqa_extended_test.py(ort_fusions)gqaORT fusiongqa_packed_qkv_extended_test.py(ort_fusions)gqapacked QKVcos_sin_cache_extended_test.py(ort_fusions)cos_sin_cacheKey patterns used
@script()API withop.Constant(value=...)for embedded constants"B","S") where appropriateModified existing file
_rotary_embedding_models.py: Added_make_partial_rotary_script(mismatched)generator for clean negative test constructionFollow-up to PR #2896.