
[bugfix]: fix SDPAImpl.forward() argument mismatch when VSA backend unavailable #1183

Open
Kash6 wants to merge 2 commits into hao-ai-lab:main from Kash6:fix/sdpa-forward-arg-count

Conversation

@Kash6

@Kash6 Kash6 commented Mar 26, 2026

Purpose

Fixes #817

When FASTVIDEO_ATTENTION_BACKEND=VIDEO_SPARSE_ATTN is set but the VSA kernel isn't available (e.g. on MPS/macOS), the model still constructs WanTransformerBlock_VSA blocks based on the env var string. However, the actual attention backend is resolved independently by the platform, and MPS always returns SDPA.

This causes DistributedAttention_VSA.forward() to pass gate_compress as a 5th positional arg to SDPAImpl.forward(), which only accepts 4:

TypeError: SDPAImpl.forward() takes 5 positional arguments but 6 were given

Root Cause

Two decisions are decoupled:

  1. wanvideo.py:587 picks WanTransformerBlock_VSA based on the env var string.
  2. The actual attn_impl backend is resolved by the platform (MPS always returns SDPA).

Only VideoSparseAttentionImpl.forward() accepts gate_compress. All other backends (SDPA, FlashAttn, SageAttn) follow the standard (query, key, value, attn_metadata) signature.
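The mismatch can be reproduced with minimal stubs (hypothetical classes mirroring the signatures described above, not actual FastVideo code):

```python
# Minimal reproduction of the argument mismatch. Only the VSA impl
# accepts gate_compress; every other backend takes four arguments.

class SDPAImpl:
    def forward(self, query, key, value, attn_metadata):
        return "sdpa"

class VideoSparseAttentionImpl:
    def forward(self, query, key, value, gate_compress, attn_metadata):
        return "vsa"

def distributed_attention_vsa_forward(attn_impl, q, k, v, gate_compress, meta):
    # Unconditionally forwarding gate_compress reproduces the bug
    # when attn_impl resolved to a non-VSA backend:
    return attn_impl.forward(q, k, v, gate_compress, meta)

try:
    distributed_attention_vsa_forward(SDPAImpl(), None, None, None, None, None)
except TypeError as e:
    # On Python 3.10+ this reads:
    # SDPAImpl.forward() takes 5 positional arguments but 6 were given
    print(e)
```

The count includes `self`, which is why a 4-parameter method reports "takes 5 positional arguments".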

Changes

Check self.backend in DistributedAttention_VSA.forward() before deciding whether to pass gate_compress. Only VIDEO_SPARSE_ATTN gets it; all other backends get the standard 4-arg call.
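A sketch of the call-site guard (attribute and enum names are assumptions based on the PR description, not verbatim FastVideo code):

```python
from enum import Enum

class AttentionBackendEnum(Enum):
    TORCH_SDPA = "TORCH_SDPA"
    VIDEO_SPARSE_ATTN = "VIDEO_SPARSE_ATTN"

class DistributedAttention_VSA:
    def __init__(self, attn_impl, backend):
        self.attn_impl = attn_impl
        self.backend = backend  # the backend resolved by the platform

    def forward(self, q, k, v, gate_compress, attn_metadata):
        if self.backend == AttentionBackendEnum.VIDEO_SPARSE_ATTN:
            # Only the VSA impl accepts gate_compress.
            return self.attn_impl.forward(q, k, v, gate_compress, attn_metadata)
        # All other backends get the standard 4-arg call.
        return self.attn_impl.forward(q, k, v, attn_metadata)
```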

Test Plan

pre-commit run --all-files
pytest tests/ --ignore=tests/local_tests/pipelines/test_ltx2_pipeline_smoke.py -v

Test Results

  • pre-commit run --all-files: all checks pass (yapf, ruff, codespell, actionlint, filenames). mypy has 4 pre-existing errors on main (unrelated VideoGenerator import).
  • pytest tests/: 6 passed, 23 skipped, 3 failed; all 3 failures are pre-existing on main. Zero new regressions.

Checklist

  • I ran pre-commit run --all-files and fixed all issues
  • I added or updated tests for my changes
  • I updated documentation if needed
  • I considered GPU memory impact of my changes

For model/pipeline changes, also check:

  • I verified SSIM regression tests pass
  • I updated the support matrix if adding a new model

Note: #1182 takes a different approach by adding an unused gate_compress parameter to every attention backend. This PR instead checks the resolved backend at the call site in DistributedAttention_VSA, which is less invasive (1 file, 4 lines) and avoids polluting backend signatures with parameters they don't use. It also directly answers the maintainer's question on #1182: DistributedAttention_VSA gets a non-VSA backend because wanvideo.py selects the block type from the env var string, while the actual backend is resolved independently by the platform (MPS always returns SDPA).

…ao-ai-lab#817)

DistributedAttention_VSA unconditionally passed gate_compress to
attn_impl.forward(), but only VideoSparseAttentionImpl accepts it.
When the platform resolves a different backend (e.g. SDPA on MPS),
this caused TypeError: SDPAImpl.forward() takes 5 positional
arguments but 6 were given.

Only pass gate_compress when the resolved backend is VIDEO_SPARSE_ATTN.

@github-actions github-actions bot left a comment


Welcome to FastVideo! Thanks for your first pull request.

How our CI works:

PRs run a three-tier CI system:

  1. Pre-commit — formatting (yapf), linting (ruff), type checking (mypy). Runs immediately on every PR.
  2. Fastcheck — core GPU tests (encoders, VAEs, transformers, kernels, unit tests). Runs automatically via Buildkite on relevant file changes (~10-15 min).
  3. Full Suite — integration tests, training pipelines, SSIM regression. Runs only when a reviewer adds the ready label.

Before your PR is reviewed:

  • pre-commit run --all-files passes locally
  • You've added or updated tests for your changes
  • The PR description explains what and why

If pre-commit fails, a bot comment will explain how to fix it. Fastcheck and Full Suite results appear in the Checks section below.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the attention layer's forward method to conditionally pass the gate_compress tensor based on the active backend, resolving an argument mismatch issue. Feedback suggests optimizing performance by conditionally including gate_compress in the processing pipeline only when required by the VIDEO_SPARSE_ATTN backend to avoid redundant computations and data transfers.

Comment thread on fastvideo/attention/layer.py (Outdated)
@loaydatrain
Collaborator

There must be a simpler way to deal with this than polluting the layer.py code. How about: in FastVideo/fastvideo/platforms/mps.py, if VSA is selected with the MPS backend, raise an error so the user is notified that VSA isn't supported and told how to unset the env var?

@alexzms
Collaborator

alexzms commented Mar 27, 2026

There must be a simpler way to deal with this than polluting the layer.py code. How about: in FastVideo/fastvideo/platforms/mps.py, if VSA is selected with the MPS backend, raise an error so the user is notified that VSA isn't supported and told how to unset the env var?

agree

@Kash6
Author

Kash6 commented Mar 27, 2026

@loaydatrain @alexzms Sure, adding an early error in mps.py when VSA is requested would give users a much clearer message. That said, the layer.py guard is still valuable as a safety net: the block type selection (WanTransformerBlock_VSA) and backend resolution are decoupled, so the mismatch could happen on CUDA too (if the VSA import succeeds at model init but a different backend is resolved for a specific attention layer). I can add the mps.py early error as an additional commit on this PR if needed.

@loaydatrain
Collaborator

loaydatrain commented Mar 28, 2026

For now I think it is best to handle it in mps.py and leave the attn impl as it is. In the near future I was planning a small refactor where I get rid of the *_VSA functions from both wanvideo.py and layer.py, which would be a stronger fix for this

(if VSA import succeeds at model init but a different backend is resolved for a specific attention layer)

Out of curiosity, when would such an error occur? Either way, I think the best thing to do in that case is to raise an error than incorrectly make the user think that VSA is being used.

…ao-ai-lab#817)

Revert the layer.py forward() guard. Instead, raise NotImplementedError
in mps.py when VIDEO_SPARSE_ATTN is requested, giving users a clear
message to unset the env var or use TORCH_SDPA.

This avoids silently dropping gate_compress and incorrectly running
SDPA while the user thinks VSA is active.
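The mps.py guard described in this commit could look roughly like the following sketch (class and method names are illustrative, not FastVideo's exact platform API):

```python
# Fail fast on MPS when VSA is requested, instead of silently
# resolving to SDPA while the user believes VSA is active.

class MpsPlatform:
    def get_attn_backend(self, selected_backend: str) -> str:
        if selected_backend == "VIDEO_SPARSE_ATTN":
            raise NotImplementedError(
                "VIDEO_SPARSE_ATTN is not supported on MPS. "
                "Unset FASTVIDEO_ATTENTION_BACKEND or set it to TORCH_SDPA.")
        # MPS always resolves to torch SDPA.
        return "TORCH_SDPA"
```

Raising here surfaces the misconfiguration at startup rather than deep inside an attention forward pass.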
@Kash6
Author

Kash6 commented Mar 28, 2026

My bad, I looked at this more carefully and the CUDA fallback scenario I described doesn't actually happen. On CUDA, if the VSA import fails it raises ImportError immediately in cuda.py rather than silently falling back. The only real-world path to this bug is MPS, which always returns SDPA regardless of the env var.

Good catch on the silent degradation.

Updated the PR: reverted the layer.py change and added a NotImplementedError in mps.py when VIDEO_SPARSE_ATTN is requested. Looking forward to the *_VSA refactor.

@mergify mergify bot added the type: bugfix Bug fix label Mar 30, 2026
@mergify
Contributor

mergify bot commented Mar 30, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🟠 PR merge requirements

Waiting for:

  • check-success=fastcheck-passed
  • check-success=full-suite-passed
  • #approved-reviews-by>=1
  • check-success~=pre-commit
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model)\]

Waiting checks: fastcheck-passed, full-suite-passed.

@Kash6
Author

Kash6 commented Apr 1, 2026

Updated per feedback: moved the fix to mps.py and reverted layer.py. Ready for re-review when you get a chance @loaydatrain @alexzms

@Kash6
Author

Kash6 commented Apr 8, 2026

@alexzms @loaydatrain Looks like Mergify is waiting on the full-suite-passed and pre-commit checks. Do I need the ready label added to trigger those?



Development

Successfully merging this pull request may close these issues.

[Bug] 1.6. INcompatibilities

3 participants