Commit 605e06e
authored
Add fusion rules (Whisper optimizations) (#2221)
Add fusion rules to support the optimization of Whisper models.
Fusions added:
- Basic Fusions:
* additional pattern for erfgelu [moved to #2222]
- SkipLayerNorm:
* #2259
* Fusion patterns where skip_sum is also an output
* Bias + SkipLayerNorm -> SkipLayerNorm (with bias) [moved to #2222]
- BiasGelu Fusion [moved to #2222]
- SDPA:
* Support for pattern where only q is pre-scaled
- MHA:
* Patterns with/without past/present keys/values
* Patterns with non-rotary embeddings
* Patterns with/without mask
* Patterns with cross-attention (only for past key/value patterns)
- MHA Bias Fusion:
* Bias was offloaded to Attention fusion previously, this fusion fixes
that
- Attention:
* Patterns where Q, K and V do not come from slicing
TODO:
- [x] Fix SDPA singular prescale case, due to lost shape information
- [x] - Enable check conditions when #2210 is merged
- [x] - Improve/Rewrite whisper model test case to be similar to that of
smollm (for eg)
- [x] - Fix failing test cases to account for new patterns
- [x] - Add isolated test cases for new fusions like BiasGelu,
SkipLayerNorm etc1 parent 33f31ca commit 605e06e
File tree
11 files changed
+1357
-283
lines changed- onnxscript/rewriter/ort_fusions
- models
11 files changed
+1357
-283
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| 19 | + | |
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
| |||
79 | 80 | | |
80 | 81 | | |
81 | 82 | | |
| 83 | + | |
| 84 | + | |
82 | 85 | | |
83 | 86 | | |
84 | 87 | | |
| |||
87 | 90 | | |
88 | 91 | | |
89 | 92 | | |
| 93 | + | |
90 | 94 | | |
91 | 95 | | |
| 96 | + | |
92 | 97 | | |
93 | 98 | | |
94 | 99 | | |
| |||
0 commit comments