Add BiasGelu, Erfgelu and SkipLayerNormalization fusions #2222
Merged: shubhambhokare1 merged 5 commits into main (Apr 24, 2025)
Conversation
❌ 4 tests failed (details in the Test Analytics Dashboard).
Contributor
Pull Request Overview
This PR adds new fusion rules for BiasGelu and Erfgelu, and updates the SkipLayerNormalization fusion to support an additional output. Key changes include:
- Updating SkipLayerNormalization to return a new output (skip_sum) and adding a corresponding fusion rule for Add+SkipLayerNormalization.
- Implementing a new BiasGelu fusion along with its test.
- Updating the Erfgelu rewrite with two pattern functions for broader matching.
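For reference, the BiasGelu fusion relies on the identity Gelu(x + bias) == BiasGelu(x, bias). Below is a minimal numpy sketch of that equivalence; the exact (erf-based) Gelu formula is assumed from the ONNX operator definition, and the function names are illustrative, not the PR's actual code:

```python
import math

import numpy as np

# math.erf applied elementwise (plain numpy has no erf ufunc)
_erf = np.frompyfunc(math.erf, 1, 1)

def gelu(x: np.ndarray) -> np.ndarray:
    # Exact (erf-based) Gelu per the ONNX definition:
    #   Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)).astype(x.dtype))

def bias_gelu(x: np.ndarray, bias: np.ndarray) -> np.ndarray:
    # The fused op: BiasGelu(x, bias) == Gelu(x + bias)
    return gelu(x + bias)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4)).astype(np.float32)
b = rng.standard_normal(4).astype(np.float32)

unfused = gelu(x + b)    # graph before fusion: Add -> Gelu
fused = bias_gelu(x, b)  # graph after fusion: a single BiasGelu node
assert np.allclose(unfused, fused)
```

The fusion is purely structural: it collapses an Add feeding a Gelu into one node so ONNX Runtime can dispatch its fused BiasGelu kernel.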
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| onnxscript/rewriter/ort_fusions/skip_normalization.py | Modified SkipLayerNormalization to return an extra output and added support for bias fusion. |
| onnxscript/rewriter/ort_fusions/bias_gelu_test.py | Added a unit test to verify the BiasGelu fusion. |
| onnxscript/rewriter/ort_fusions/bias_gelu.py | Implemented the BiasGelu fusion rule. |
| onnxscript/rewriter/ort_fusions/_core.py | Updated the fusion count to include BiasGelu fusion. |
| onnxscript/rewriter/erfgelu.py | Introduced two Erfgelu pattern rules to enhance rewrite flexibility. |
| onnxscript/rewriter/__init__.py | Updated to include the new Erfgelu fusion rule. |
Comments suppressed due to low confidence (1)
onnxscript/rewriter/erfgelu.py:9
- [nitpick] The current names 'erf_gelu_pattern_1' and 'erf_gelu_pattern_2' are not very descriptive. Consider renaming them to indicate the specific matching or transformation behavior they implement.
def erf_gelu_pattern_1(op, x):
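The two pattern functions exist because exported models can express the same erf-based Gelu with different op orderings, and the matcher needs one pattern per shape. The PR's exact orderings are not shown in this excerpt; the two algebraically equivalent variants below are illustrative stand-ins:

```python
import math

SQRT2 = math.sqrt(2.0)

def erf_gelu_v1(x: float) -> float:
    # Ordering 1: x * 0.5 * (1 + erf(x / sqrt(2)))
    return x * 0.5 * (1.0 + math.erf(x / SQRT2))

def erf_gelu_v2(x: float) -> float:
    # Ordering 2 (distributed form): 0.5 * (x + x * erf(x / sqrt(2)))
    return 0.5 * (x + x * math.erf(x / SQRT2))

# Both orderings compute the same Gelu value, so both patterns
# can safely rewrite to a single Gelu node.
for v in (-2.0, -0.5, 0.0, 0.7, 3.0):
    assert math.isclose(erf_gelu_v1(v), erf_gelu_v2(v), abs_tol=1e-12)
```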
justinchuby reviewed Apr 23, 2025
gramalingam reviewed Apr 24, 2025
gramalingam reviewed Apr 24, 2025
gramalingam approved these changes Apr 24, 2025
Collaborator
gramalingam left a comment:
Just a minor suggestion to use onnxscript for the test-case model-proto.
Force-pushed from 82a891f to d65da68
justinchuby approved these changes Apr 24, 2025
shubhambhokare1 added a commit that referenced this pull request on May 7, 2025:
Add fusion rules to support the optimization of Whisper models.
Fusions added:
- Basic Fusions:
* additional pattern for erfgelu [moved to #2222]
- SkipLayerNorm:
* #2259
* Fusion patterns where skip_sum is also an output
* Bias + SkipLayerNorm -> SkipLayerNorm (with bias) [moved to #2222]
- BiasGelu Fusion [moved to #2222]
- SDPA:
* Support for pattern where only q is pre-scaled
- MHA:
* Patterns with/without past/present keys/values
* Patterns with non-rotary embeddings
* Patterns with/without mask
* Patterns with cross-attention (only for past key/value patterns)
- MHA Bias Fusion:
* Bias was previously offloaded to the Attention fusion; this fusion fixes that
- Attention:
* Patterns where Q, K and V do not come from slicing
TODO:
- [x] Fix SDPA singular prescale case, due to lost shape information
- [x] Enable check conditions when #2210 is merged
- [x] Improve/rewrite the whisper model test case to be similar to that of smollm, for example
- [x] Fix failing test cases to account for new patterns
- [x] Add isolated test cases for new fusions like BiasGelu, SkipLayerNorm, etc.
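For context, the SkipLayerNormalization contract these fusions target can be sketched in numpy: the node computes LayerNorm(input + skip [+ bias]) and, in the new patterns, also exposes the pre-normalization sum (skip_sum) as an extra output. The formula below is assumed from the ONNX Runtime contrib-op definition; the epsilon value is illustrative:

```python
import numpy as np

def skip_layer_norm(x, skip, gamma, beta, bias=None, eps=1e-12):
    # skip_sum: the pre-normalization sum, exposed as an extra output
    skip_sum = x + skip if bias is None else x + skip + bias
    mean = skip_sum.mean(axis=-1, keepdims=True)
    var = skip_sum.var(axis=-1, keepdims=True)
    out = (skip_sum - mean) / np.sqrt(var + eps) * gamma + beta
    return out, skip_sum

rng = np.random.default_rng(1)
x = rng.standard_normal((2, 3, 8))
skip = rng.standard_normal((2, 3, 8))
bias = rng.standard_normal(8)
gamma = np.ones(8)
beta = np.zeros(8)

# "Bias + SkipLayerNorm -> SkipLayerNorm (with bias)": folding the bias
# into the node matches adding it to the input beforehand.
out_fused, sum_fused = skip_layer_norm(x, skip, gamma, beta, bias=bias)
out_unfused, sum_unfused = skip_layer_norm(x + bias, skip, gamma, beta)
assert np.allclose(out_fused, out_unfused)
assert np.allclose(sum_fused, sum_unfused)
```

Exposing skip_sum matters for transformer stacks, where the residual sum feeds the next layer as well as the normalization.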
This pull request introduces new fusion patterns and enhancements to the ONNXScript rewriter module, focusing on optimization and test-coverage improvements. The key changes include adding support for `BiasGelu` and additional `ErfGelu` patterns, extending `SkipLayerNormalization` to handle bias addition, and updating test utilities for better accuracy validation.

New fusion patterns:
- **BiasGelu Fusion**: Added a new fusion pattern for `BiasGelu` operations, including its implementation in `onnxscript/rewriter/ort_fusions/bias_gelu.py` and integration into the `fuse_xformers` pipeline. A corresponding unit test was added to validate the functionality. [1] [2] [3] [4]
- **ErfGelu Enhancements**: Introduced a second pattern for `ErfGelu` fusion and refactored the corresponding implementation. The file was renamed from `erfgelu.py` to `ort_fusions/erfgelu.py` for consistency. [1] [2] [3] [4]

Enhancements to existing fusions:
- Extended the `SkipLayerNormalization` fusion to support an additional bias term. This includes new patterns and rewrite rules in `onnxscript/rewriter/ort_fusions/skip_normalization.py`.

Test utility updates:
- Relaxed the tolerance in `assert_allclose` to `1e-3` for better handling of numerical discrepancies in tests.
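The effect of the relaxed tolerance can be illustrated with numpy's `assert_allclose`; the project's own test utility is not shown here, so this assumes a numpy-style `rtol`/`atol` comparison:

```python
import numpy as np

expected = np.array([1.0, 0.5, -2.0], dtype=np.float32)
# simulated fused-kernel output with small floating-point drift
actual = expected + np.float32(5e-4)

# passes under the relaxed 1e-3 absolute tolerance
np.testing.assert_allclose(actual, expected, rtol=0.0, atol=1e-3)

# the same comparison fails under a tighter 1e-5 tolerance
tight_failed = False
try:
    np.testing.assert_allclose(actual, expected, rtol=0.0, atol=1e-5)
except AssertionError:
    tight_failed = True
assert tight_failed
```

Fused kernels reorder float arithmetic, so results legitimately differ from the unfused graph by small amounts; the looser tolerance keeps such runs from being flagged as failures.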