
Attention mask for GQA fusion#2452

Merged
justinchuby merged 14 commits into main from rama/gqa-mask on Jul 21, 2025

Conversation

@gramalingam
Collaborator

@gramalingam gramalingam commented Jul 14, 2025

Expand the GQA fusion rule to handle attention mask better.

  • The patterns are extended to handle variations found in the attention-mask logic of various models.
  • It incorporates some optimizations from ModelBuilder that are arguably not general-purpose, as they make assumptions about the intended use case (the GenAI usage pattern).

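For context, GQA shares each key/value head across a group of query heads; the repeat-KV-heads pattern below is the shape of computation that the fusion rule collapses into a single fused op. A minimal NumPy sketch under that assumption (the function name and layout are illustrative, not the actual onnxscript API):

```python
import numpy as np

def group_query_attention(q, k, v, num_kv_heads):
    """Illustrative GQA: query heads share K/V heads in groups.

    q: (num_q_heads, seq, d); k, v: (num_kv_heads, seq, d).
    Hypothetical helper, not the fusion's actual signature.
    """
    num_q_heads = q.shape[0]
    group = num_q_heads // num_kv_heads
    # Repeat each K/V head across its query group -- the expanded
    # subgraph that the fusion replaces with one GQA op.
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    # Numerically stable softmax over the key dimension.
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)
    return weights @ v
```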
Copilot AI and others added 10 commits July 11, 2025 16:11
…tern matching

Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
This change addresses @gramalingam's feedback to return the match object
(which includes success/failure status) instead of always returning None
when the initial pattern match fails. This provides more consistent API
behavior and makes failure information available when applicable.

- Changed PatternImpl.match() to return match object on line 161
- Updated RewriteRule.try_rewrite() to use "if not match:" instead of "if match is None:"
- Added test case to verify both None and failed MatchResult are handled correctly
- Backward compatible: None still returned for GenericPatternMatcher, failed MatchResult for SimplePatternMatcher
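The new contract can be illustrated with a minimal sketch (class and function names here are hypothetical, not the actual onnxscript types): a match object that is falsy on failure, so `if not match:` handles a plain `None` and a failed match result uniformly while keeping failure details available.

```python
class MatchResult:
    """Sketch of a match object whose truthiness reflects success."""

    def __init__(self, success, reason=""):
        self.success = success
        self.reason = reason  # failure info stays available to callers

    def __bool__(self):
        # Makes `if not match:` equivalent for None and failed matches.
        return self.success

def try_rewrite(match):
    # Works whether the matcher returned None (GenericPatternMatcher)
    # or a failed MatchResult (SimplePatternMatcher).
    if not match:
        return None
    return "rewritten"
```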

Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
…g failure

Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
@codecov

codecov Bot commented Jul 14, 2025

❌ 15 Tests Failed:

Tests completed: 16636 | Failed: 15 | Passed: 16621 | Skipped: 3672
View the top 3 failed test(s) by shortest run time
::onnxscript.tools.training_helper
Stack Traces | 0s run time
ImportError while importing test module '.../onnxscript/tools/training_helper.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
.../Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/tools/training_helper.py:6: in <module>
    from torch.onnx import _OrtBackend, _OrtBackendOptions
E   ImportError: cannot import name '_OrtBackend' from 'torch.onnx' (.../onnxscript/onnxscript/.nox.../test_torch_nightly/lib/python3.11.../torch/onnx/__init__.py)
::onnxscript.tools.transformers_models.llama_test
Stack Traces | 0s run time
ImportError while importing test module '.../tools/transformers_models/llama_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
.../Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../tools/transformers_models/llama_test.py:12: in <module>
    import onnxscript.tools.training_helper
onnxscript/tools/training_helper.py:6: in <module>
    from torch.onnx import _OrtBackend, _OrtBackendOptions
E   ImportError: cannot import name '_OrtBackend' from 'torch.onnx' (.../onnxscript/onnxscript/.nox.../test_torch_nightly/lib/python3.11.../torch/onnx/__init__.py)
::onnxscript.tools.transformers_models.mistral_test
Stack Traces | 0s run time
ImportError while importing test module '.../tools/transformers_models/mistral_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
.../Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../tools/transformers_models/mistral_test.py:14: in <module>
    import onnxscript.tools.training_helper
onnxscript/tools/training_helper.py:6: in <module>
    from torch.onnx import _OrtBackend, _OrtBackendOptions
E   ImportError: cannot import name '_OrtBackend' from 'torch.onnx' (.../onnxscript/onnxscript/.nox.../test_torch_nightly/lib/python3.11.../torch/onnx/__init__.py)


Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
@gramalingam gramalingam marked this pull request as ready for review July 18, 2025 16:47
@gramalingam gramalingam changed the title [DRAFT] Attention mask for GQA fusion Attention mask for GQA fusion Jul 18, 2025
@titaiwangms titaiwangms requested a review from Copilot July 18, 2025 20:24
Contributor

Copilot AI left a comment


Pull Request Overview

This PR enhances the GQA (Group Query Attention) fusion rule to better handle attention masks, particularly for causal masks with optional sliding window support. The changes simplify the fusion logic while making it more robust for various model patterns.

Key changes:

  • Refactored causal mask pattern to support sliding window attention and multiple model variations
  • Simplified GQA fusion by removing the separate GQACausalMask rule and consolidating logic
  • Updated position_ids handling to use a single parameter instead of separate query/key position IDs
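As a rough illustration of the mask semantics the extended pattern targets, here is a minimal NumPy sketch of a causal mask with an optional sliding window (names are illustrative; the actual fusion matches an ONNX subgraph, not this code):

```python
import numpy as np

def causal_mask(seq_len, sliding_window=None):
    """Boolean attention mask: position i may attend to j <= i, and,
    with a sliding window of size w, only to the most recent w
    positions. Illustrative helper, not the fused op's interface."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    mask = j <= i  # causal: no attending to future positions
    if sliding_window is not None:
        # Restrict attention to the last `sliding_window` positions.
        mask &= j > i - sliding_window
    return mask
```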
Comments suppressed due to low confidence (1)

onnxscript/rewriter/ort_fusions/gqa.py:136

  • The sliding window functionality is implemented in the pattern (lines 69-74) but explicitly disabled in the check method. This creates untested code paths that could fail silently. Consider either removing the sliding window implementation or adding test coverage for when it's enabled.
        # TODO(rama) Sliding window: not yet supported.

Comment thread onnxscript/rewriter/ort_fusions/gqa.py
Comment thread onnxscript/rewriter/ort_fusions/gqa.py
Comment thread onnxscript/rewriter/ort_fusions/gqa.py
@titaiwangms
Contributor

It incorporates some optimizations from ModelBuilder that are arguably not general-purpose, as they make assumptions about the intended use case (the GenAI usage pattern).

Does it make sense to put a NOTE on these patterns/optimizations?

Contributor

@titaiwangms titaiwangms left a comment


Is it possible to add a test, or is it already covered?

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
@gramalingam
Collaborator Author

It incorporates some optimizations from ModelBuilder that are arguably not general-purpose, as they make assumptions about the intended use case (the GenAI usage pattern).

Does it make sense to put a NOTE on these patterns/optimizations?

Yes, I added a comment to explain this.

@gramalingam
Collaborator Author

Is it possible to add a test, or is it already covered?

I have tried it on the Qwen2_5, Mistralai, and Phi2LM models. Agreed, we should add test cases, but I wanted to merge this since Tomasso has changes that he will want to merge before his internship ends soon. I think some further refinements of the mask pattern may be necessary anyway (for TinyLLM and/or Phi4).

@justinchuby justinchuby merged commit 3f2f7d3 into main Jul 21, 2025
23 of 30 checks passed
@justinchuby justinchuby deleted the rama/gqa-mask branch July 21, 2025 22:08


5 participants