
Support enable_gqa and only support 4D Q, K, and V #2558

Merged
titaiwangms merged 1 commit into microsoft:main from titaiwangms:titaiwang/support_enable_gqa
Sep 12, 2025

Conversation

@titaiwangms
Contributor

  1. Support enable_gqa.
  2. Align with the PyTorch exporter, which rejects Q, K, and V when they are not 4D: https://github.com/pytorch/pytorch/blob/62843c14bbf694f5722fd6e1075da4792507fe42/torch/onnx/_internal/exporter/_torchlib/ops/nn.py#L131-L133

NOTE: torch.nn.functional.scaled_dot_product_attention itself accepts 3D inputs, and the op tests even exercise a 3D Q with 4D K and V.

# Skip op-test samples where any of Q, K, or V is not 4D.
matcher=lambda sample: len(sample.input.shape) != 4
or len(sample.args[0].shape) != 4
or len(sample.args[1].shape) != 4,
reason="torch sdpa is expected to pass in 4d q, k, and v.",
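For context, a minimal NumPy sketch of what enable_gqa means for 4D inputs (this is illustrative, not the onnxscript implementation; `sdpa_gqa` and the shapes below are assumptions): the K/V heads are repeated until they match Q's head count, after which ordinary scaled dot-product attention applies.

```python
import numpy as np

def sdpa_gqa(q, k, v, enable_gqa=False):
    # q: (B, Hq, Lq, E); k, v: (B, Hkv, Lkv, E) with Hq % Hkv == 0.
    if enable_gqa and q.shape[1] != k.shape[1]:
        repeat = q.shape[1] // k.shape[1]
        k = np.repeat(k, repeat, axis=1)  # expand KV heads to match Q's heads
        v = np.repeat(v, repeat, axis=1)
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ k.transpose(0, 1, 3, 2)) * scale  # (B, Hq, Lq, Lkv)
    # Numerically stable softmax over the key dimension.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v  # (B, Hq, Lq, E)

q = np.random.rand(2, 8, 5, 16)   # 8 query heads
k = np.random.rand(2, 2, 7, 16)   # 2 KV heads (grouped-query)
v = np.random.rand(2, 2, 7, 16)
out = sdpa_gqa(q, k, v, enable_gqa=True)
print(out.shape)  # (2, 8, 5, 16)
```

With enable_gqa the head-count mismatch between Q and K/V is resolved before the matmul, which is why all three inputs being 4D makes the lowering straightforward.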
Contributor Author


@justinchuby @xadupre Let me know what you think about whether we should support only 4D QKV, or fully support whatever torch sdpa supports. Right now, torch sdpa appears to accept Q, K, and V in 3D or 4D, and even a 3D Q with 4D K and V.

Collaborator


It depends on the ATen op. Does the nn function preprocess the inputs before sending them to the kernel? We just need to support whatever the kernel supports.

@codecov

codecov Bot commented Sep 11, 2025

Codecov Report

❌ Patch coverage is 14.28571% with 24 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.32%. Comparing base (647b22a) to head (e661531).
⚠️ Report is 6 commits behind head on main.

Files with missing lines                          Patch %   Lines
onnxscript/function_libs/torch_lib/ops/nn.py      14.28%    22 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2558      +/-   ##
==========================================
- Coverage   70.34%   70.32%   -0.03%     
==========================================
  Files         218      222       +4     
  Lines       26430    26645     +215     
  Branches     2647     2663      +16     
==========================================
+ Hits        18593    18738     +145     
- Misses       6934     6991      +57     
- Partials      903      916      +13     


Collaborator

@justinchuby justinchuby left a comment


Thanks!

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Sep 11, 2025
@justinchuby
Collaborator

Could you also add these few lines

if dropout_p > 0.0:
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)

as a micro optimization? Or we can do that separately

@titaiwangms
Contributor Author

Could you also add these few lines

if dropout_p > 0.0:
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)

as a micro optimization? Or we can do that separately

it's already there:

attn_weight, _ = op.Dropout(attn_weight, dropout_p)
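A minimal sketch of why the `dropout_p > 0.0` guard is a micro optimization, assuming dropout_p is an ordinary Python float at graph-building time (`build_attention_tail` and the node-list representation are hypothetical stand-ins, not the onnxscript API): the branch is resolved while the graph is constructed, so no Dropout node is ever emitted when dropout_p == 0.0.

```python
def build_attention_tail(graph_nodes, attn_weight, dropout_p):
    # dropout_p is a plain Python float during export, so this `if` is
    # evaluated once at graph-construction time, not baked into the graph.
    if dropout_p > 0.0:
        # Stand-in for op.Dropout(attn_weight, dropout_p).
        graph_nodes.append(("Dropout", attn_weight, dropout_p))
    return graph_nodes

nodes = build_attention_tail([], "attn_weight", 0.0)
print(len(nodes))  # 0: no Dropout node is emitted for dropout_p == 0.0
```

Without the guard, every exported graph would carry a no-op Dropout node with ratio 0.0 that downstream optimizers would have to fold away.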

@titaiwangms titaiwangms enabled auto-merge (squash) September 12, 2025 00:37
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Sep 12, 2025
@titaiwangms titaiwangms merged commit 8ed3521 into microsoft:main Sep 12, 2025
32 checks passed
