Support enable_gqa and only support 4D Q, K, and V #2558
titaiwangms merged 1 commit into microsoft:main
matcher=lambda sample: len(sample.input.shape) != 4
    or len(sample.args[0].shape) != 4
    or len(sample.args[1].shape) != 4,
reason="torch sdpa is expected to pass in 4d q, k, and v.",
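To make the skip condition above concrete, here is a minimal standalone sketch of the same predicate. `FakeTensor` and `FakeSample` are hypothetical stand-ins invented for illustration (they are not part of onnxscript or PyTorch), and the sketch assumes that `sample.input` is Q while `sample.args[0]` and `sample.args[1]` are K and V, as the reason string suggests.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; only the shape matters here."""

    shape: Tuple[int, ...]


@dataclass
class FakeSample:
    """Hypothetical stand-in for an op-test sample (assumed: input=Q, args=(K, V))."""

    input: FakeTensor
    args: Tuple[FakeTensor, ...]


def is_not_all_4d(sample: FakeSample) -> bool:
    """Return True when any of Q, K, or V is not 4D, i.e. the sample matches the skip."""
    return (
        len(sample.input.shape) != 4
        or len(sample.args[0].shape) != 4
        or len(sample.args[1].shape) != 4
    )


# Illustrative shapes: (batch, heads, seq_len, head_dim) for 4D, no batch dim for 3D.
q4 = FakeTensor((2, 8, 16, 64))
kv4 = FakeTensor((2, 8, 16, 64))
q3 = FakeTensor((8, 16, 64))

all_4d = FakeSample(input=q4, args=(kv4, kv4))
mixed = FakeSample(input=q3, args=(kv4, kv4))
```

With these samples, `is_not_all_4d(all_4d)` is false (the test would run) and `is_not_all_4d(mixed)` is true (the 3D Q with 4D K and V case would be skipped).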
@justinchuby @xadupre Let me know what you think: should we support only 4D Q, K, and V, or should we fully support whatever torch sdpa supports? Right now, it seems that in torch sdpa Q, K, and V can be 3D or 4D, or even a 3D Q with 4D K and V.
Does it depend on the ATen op? Does the nn function do preprocessing on the inputs before sending them to the kernel? We just need to support whatever the kernel supports.
Codecov Report ❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2558      +/-   ##
==========================================
- Coverage   70.34%   70.32%   -0.03%
==========================================
  Files         218      222       +4
  Lines       26430    26645     +215
  Branches     2647     2663      +16
==========================================
+ Hits        18593    18738     +145
- Misses       6934     6991      +57
- Partials      903      916      +13
Could you also add these few lines as a micro optimization? Or we can do that separately.
It's already there:
Fixes pytorch#162258 Related to microsoft/onnxscript#2558 Pull Request resolved: pytorch#162771 Approved by: https://github.com/justinchuby
enable_gqa
NOTE: torch.nn.functional.scaled_dot_product_attention actually supports 3D inputs, and even a 3D Q with 4D K and V, in op tests.
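For readers unfamiliar with `enable_gqa`, the sketch below illustrates the head mapping that grouped-query attention implies for 4D inputs. It assumes the `repeat_interleave`-style expansion described in the PyTorch documentation, where each K/V head is shared by a contiguous group of Q heads. The function names are hypothetical and this is not the code from this PR.

```python
from typing import Tuple


def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Map a query head index to the K/V head it attends with under GQA.

    Assumes K/V heads are expanded repeat_interleave-style, so consecutive
    query heads share the same K/V head.
    """
    if num_kv_heads <= 0 or num_q_heads % num_kv_heads != 0:
        raise ValueError("number of Q heads must be a multiple of the number of K/V heads")
    if not 0 <= q_head < num_q_heads:
        raise IndexError("query head index out of range")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


def expanded_kv_shape(
    q_shape: Tuple[int, int, int, int], kv_shape: Tuple[int, int, int, int]
) -> Tuple[int, int, int, int]:
    """Shape of K (or V) after its head axis is expanded to match Q.

    Both shapes are assumed to be 4D: (batch, heads, seq_len, head_dim).
    """
    batch, num_kv_heads, kv_len, head_dim = kv_shape
    num_q_heads = q_shape[1]
    if num_kv_heads <= 0 or num_q_heads % num_kv_heads != 0:
        raise ValueError("number of Q heads must be a multiple of the number of K/V heads")
    return (batch, num_q_heads, kv_len, head_dim)


# Example: 8 query heads sharing 2 K/V heads (groups of 4).
mapping = [kv_head_for_query_head(h, 8, 2) for h in range(8)]
```

Here `mapping` is `[0, 0, 0, 0, 1, 1, 1, 1]`: the first four query heads use K/V head 0 and the last four use K/V head 1.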