[bugfix]: fix SDPAImpl.forward() argument mismatch when VSA backend unavailable#1183
Kash6 wants to merge 2 commits into hao-ai-lab:main
Conversation
…ao-ai-lab#817) DistributedAttention_VSA unconditionally passed gate_compress to attn_impl.forward(), but only VideoSparseAttentionImpl accepts it. When the platform resolves a different backend (e.g. SDPA on MPS), this caused TypeError: SDPAImpl.forward() takes 5 positional arguments but 6 were given. Only pass gate_compress when the resolved backend is VIDEO_SPARSE_ATTN.
Welcome to FastVideo! Thanks for your first pull request.
How our CI works:
PRs run a tiered CI system:
- Pre-commit — formatting (yapf), linting (ruff), type checking (mypy). Runs immediately on every PR.
- Fastcheck — core GPU tests (encoders, VAEs, transformers, kernels, unit tests). Runs automatically via Buildkite on relevant file changes (~10-15 min).
- Full Suite — integration tests, training pipelines, SSIM regression. Runs only when a reviewer adds the ready label.
Before your PR is reviewed:
- pre-commit run --all-files passes locally
- You've added or updated tests for your changes
- The PR description explains what and why
If pre-commit fails, a bot comment will explain how to fix it. Fastcheck and Full Suite results appear in the Checks section below.
Code Review
This pull request updates the attention layer's forward method to conditionally pass the gate_compress tensor based on the active backend, resolving an argument mismatch issue. Feedback suggests optimizing performance by conditionally including gate_compress in the processing pipeline only when required by the VIDEO_SPARSE_ATTN backend to avoid redundant computations and data transfers.
There must be a simpler way to deal with this than polluting the
agree
@loaydatrain @alexzms Sure, adding an early error in mps.py when VSA is requested would give users a much clearer message. That said, the layer.py guard is still valuable as a safety net: the block type selection (WanTransformerBlock_VSA) and backend resolution are decoupled, so the mismatch could happen on CUDA too (if VSA import succeeds at model init but a different backend is resolved for a specific attention layer). I can add the mps.py early error as an additional commit on this PR if needed.
For now I think it is best to handle it in
Out of curiosity, when would such an error occur? Either way, I think the best thing to do in that case is to raise an error rather than incorrectly letting the user think that VSA is being used.
…ao-ai-lab#817) Revert the layer.py forward() guard. Instead, raise NotImplementedError in mps.py when VIDEO_SPARSE_ATTN is requested, giving users a clear message to unset the env var or use TORCH_SDPA. This avoids silently dropping gate_compress and incorrectly running SDPA while the user thinks VSA is active.
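A minimal sketch of such an mps.py guard, assuming a resolver function that receives the requested backend name (the function name and message text here are illustrative, not the exact FastVideo code):

```python
# Illustrative sketch of an MPS platform guard; the real resolver in
# FastVideo's mps.py has a different name and signature.
def resolve_mps_attn_backend(selected_backend: str) -> str:
    if selected_backend == "VIDEO_SPARSE_ATTN":
        # Fail loudly instead of silently running SDPA while the user
        # believes VSA is active.
        raise NotImplementedError(
            "VIDEO_SPARSE_ATTN is not supported on MPS. Unset "
            "FASTVIDEO_ATTENTION_BACKEND or set it to TORCH_SDPA.")
    # MPS resolves everything else to SDPA.
    return "TORCH_SDPA"
```

Raising here surfaces the misconfiguration at backend-resolution time, before any VSA-flavored block tries to call a 4-argument forward() with an extra argument.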
My bad, I looked at this more carefully and the CUDA fallback scenario I described doesn't actually happen. On CUDA, if VSA import fails it raises ImportError immediately in cuda.py rather than silently falling back. The only real-world path to this bug is MPS, which always returns SDPA regardless of the env var. Good catch on the silent degradation. Updated the PR: reverted the layer.py change and added a NotImplementedError in mps.py when VIDEO_SPARSE_ATTN is requested. Looking forward to the *_VSA refactor.
Merge Protections
Your pull request matches the following merge protections and will not be merged until they are valid.
🟠 PR merge requirements: waiting for required checks to pass.
Updated per feedback: moved the fix to mps.py and reverted layer.py. Ready for re-review when you get a chance @loaydatrain @alexzms
@alexzms @loaydatrain Looks like Mergify is waiting on the full-suite-passed and pre-commit checks. Do I need the ready label added to trigger those?
Purpose
Fixes #817
When FASTVIDEO_ATTENTION_BACKEND=VIDEO_SPARSE_ATTN is set but the VSA kernel isn't available (e.g. on MPS/macOS), the model still constructs WanTransformerBlock_VSA blocks based on the env var string. However, the actual attention backend is resolved independently by the platform, and MPS always returns SDPA.
This causes DistributedAttention_VSA.forward() to pass gate_compress as a fifth positional argument to SDPAImpl.forward(), which only accepts four.
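The mismatch can be reproduced with a minimal sketch; the class names mirror the FastVideo backends, but the bodies are placeholder stand-ins:

```python
class SDPAImplSketch:
    # Standard backend signature: (query, key, value, attn_metadata)
    def forward(self, query, key, value, attn_metadata):
        return "sdpa-output"

class VideoSparseAttentionImplSketch:
    # Only the VSA impl additionally accepts gate_compress
    def forward(self, query, key, value, gate_compress, attn_metadata):
        return "vsa-output"

q = k = v = gate_compress = object()
impl = SDPAImplSketch()  # what the MPS platform actually resolves
try:
    # DistributedAttention_VSA used to pass gate_compress unconditionally:
    impl.forward(q, k, v, gate_compress, None)
except TypeError as exc:
    print(exc)  # ...forward() takes 5 positional arguments but 6 were given
```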
Root Cause
Two decisions are decoupled:
1. wanvideo.py:587 picks WanTransformerBlock_VSA based on the env var string.
2. The actual attn_impl backend is resolved by the platform (MPS always resolves to SDPA).
Only VideoSparseAttentionImpl.forward() accepts gate_compress. All other backends (SDPA, FlashAttn, SageAttn) follow the standard (query, key, value, attn_metadata) signature.
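The decoupling can be sketched with two hypothetical stand-in functions (the real logic lives in wanvideo.py and the platform modules, with different names):

```python
def select_block_type(env_backend: str) -> str:
    # Block choice keyed only on the env var string (wanvideo.py:587-style)
    if env_backend == "VIDEO_SPARSE_ATTN":
        return "WanTransformerBlock_VSA"
    return "WanTransformerBlock"

def resolve_backend(platform: str, env_backend: str) -> str:
    # Platform resolution; MPS ignores the env var entirely
    if platform == "mps":
        return "TORCH_SDPA"
    return env_backend

env = "VIDEO_SPARSE_ATTN"
# On MPS the two decisions disagree: VSA blocks, but an SDPA backend.
print(select_block_type(env), resolve_backend("mps", env))
```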
Changes
Check self.backend in DistributedAttention_VSA.forward() before deciding whether to pass gate_compress. Only VIDEO_SPARSE_ATTN gets it; all other backends get the standard 4-arg call.
Test Plan
Test Results
- pre-commit run --all-files: all checks pass (yapf, ruff, codespell, actionlint, filenames). mypy has 4 pre-existing errors on main (unrelated VideoGenerator import).
- pytest tests/: 6 passed, 23 skipped, 3 failed — all 3 failures are pre-existing on main. Zero new regressions.
Checklist
- Ran pre-commit run --all-files and fixed all issues
Note: #1182 takes a different approach, adding an unused gate_compress parameter to every attention backend. This PR instead checks the resolved backend at the call site in DistributedAttention_VSA, which is less invasive (1 file, 4 lines) and avoids polluting backend signatures with parameters they don't use. It also directly answers the maintainer's question on #1182: DistributedAttention_VSA gets a non-VSA backend because wanvideo.py selects the block type from the env var string, while the actual backend is resolved independently by the platform (MPS always returns SDPA).