[BUGFIX][Mamba][Qwen3.5] Zero freed SSM cache blocks on GPU #35219
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Code Review
This pull request introduces a bug fix for Mamba-based models, specifically addressing an issue where freed SSM cache blocks on the GPU were not being zeroed out. This could lead to incorrect state being reused in subsequent computations. The fix involves implementing a mechanism to track SSM blocks that are truly freed (i.e., their reference count drops to zero) during each scheduling step. These freed block IDs are then passed to the worker, which explicitly zeroes out the corresponding state tensors on the GPU. The changes are well-contained and correctly implemented across the scheduler and worker components, ensuring that Mamba's stateful cache is properly managed. The logic for identifying and collecting freed blocks is soundly integrated into the existing KV cache management lifecycle methods.
Could you please provide data on the possible perf overhead? Also, with async scheduling I think it may be risky to zero on free; we may need to move this into the model runner to ensure it ends up in the correct order in the GPU stream.
Actual zeroing happens in |
The following is from a Slack discussion, but I think it is worth sharing here: it seems both FlashAttn and the TRT-LLM attention kernels multiply by 0 to mask unused values.
How does this interact with prefix caching? If we zero out blocks when their ref_cnt hits zero, doesn't that mean they can't be re-used if something comes along later that gets a cache hit? Wouldn't it be better to detect the event when we use a block for attention that was previously used for mamba (in some other dtype) and zero it out at that point?
Ran on B200. With and without the changes, Output Tokens vary around 15000 ± 300. Definitely nothing dramatic from a perf point of view, but these are not exact numbers.
Good catch, I broke prefix caching :/
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
70dafd6 to 152ccc2
I redid this PR along the lines of what @tdoublep proposed. I decided to zero out every new block, whether it comes from attention or from the SSM. Justification: attention can also produce NaNs in certain corner cases. Getting garbage for one specific request is likely acceptable, but without zeroing, the NaN could propagate to all requests.
pls take a look
    def _zero_block_ids(self, block_ids: list[int]) -> None:
        """Zero the raw KV cache memory for the given block IDs."""
        for raw_tensor, page_size in self.kv_cache_raw_buffers:
Would it be more efficient to build an index tensor and have one op to zero at all the block id slots?
How many block ids would we normally see for a typical prefill/decode? Is it very few?
This zeroing takes a small amount of time. We do it once per forward step and only for new blocks.
@benchislett Can you tell offhand whether this code runs in the sync or the async part?
I notice that this is not specific to SSM blocks, and it clears all new KV blocks. Will this have a detrimental effect on prefills for non-mamba deployments where block_size=16?
In this case if we get a prefill of 8k tokens, that will be 512 new blocks, right? I think that would lead to 512 kernel invocations in this implementation. If that is indeed the case, this will not suffice.
Right, I am optimizing it.
does it make sense to use torch.tensor for block ids and use a gpu operation to zero the indices in tensors?
See my comment below
> does it make sense to use torch.tensor for block ids and use a gpu operation to zero the indices in tensors?
I implemented zeroing as a triton kernel
Please let me know if there is a better way to do it.
…ject#35219) Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
…ject#35219) Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> (cherry picked from commit 8c2fc11)
…ject#35219) Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> (cherry picked from commit 03a1823)
PR vllm-project#35219 records every newly allocated full-attention/MLA block id into SingleTypeKVCacheManager.new_block_ids, but the scheduler only drains it via take_new_block_ids() when needs_kv_cache_zeroing, which equals has_mamba_layers. Models without Mamba layers therefore never drain the list, so it grows without bound and leaks host memory under sustained load (one int per allocated block per request). gc.freeze() at EngineCore startup excludes the list from gc.get_objects()/tracemalloc, which makes the growth easy to miss.

Drain the per-step block ids unconditionally in the scheduler and only use them when zeroing is enabled. This bounds the list for all models without adding a constructor flag or reading needs_kv_cache_zeroing twice; for Mamba models the drain already happened in that branch, so their behavior is unchanged.

Fixes vllm-project#44175

Signed-off-by: Ting Sun <suntcrick@gmail.com>
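The drain-unconditionally pattern described in this commit message can be sketched roughly as follows. This is a simplified, standalone illustration, not the vLLM code: `BlockTracker` and `schedule_step` are hypothetical stand-ins, and only the names `new_block_ids`, `take_new_block_ids`, and `needs_kv_cache_zeroing` are taken from the text above.

```python
from typing import List, Optional


class BlockTracker:
    """Hypothetical stand-in for a per-type KV cache manager."""

    def __init__(self) -> None:
        self.new_block_ids: List[int] = []

    def allocate(self, block_ids: List[int]) -> None:
        # Record every newly allocated block id for the current step.
        self.new_block_ids.extend(block_ids)

    def take_new_block_ids(self) -> List[int]:
        # Hand the recorded ids to the caller and reset the list.
        ids, self.new_block_ids = self.new_block_ids, []
        return ids


def schedule_step(tracker: BlockTracker,
                  needs_kv_cache_zeroing: bool) -> Optional[List[int]]:
    # Drain on every step so the list stays bounded for all models...
    new_ids = tracker.take_new_block_ids()
    # ...but only forward the ids when zeroing is enabled.
    return new_ids if needs_kv_cache_zeroing else None


# A model without Mamba layers: ids are recorded but never needed.
tracker = BlockTracker()
for step in range(3):
    tracker.allocate([2 * step, 2 * step + 1])
    schedule_step(tracker, needs_kv_cache_zeroing=False)
leftover = len(tracker.new_block_ids)

# A hybrid model: the same drain also yields the ids to zero.
tracker.allocate([10, 11])
to_zero = schedule_step(tracker, needs_kv_cache_zeroing=True)
```

With the drain inside the `if needs_kv_cache_zeroing` branch instead, `leftover` in this sketch would grow by two entries per step.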
Essential problem
Fixes #35138
Workaround for Dao-AILab/flash-attention#1974
Hybrid models (e.g. Qwen3.5-397B-A17B) share a unified block pool between attention (fp8/fp16) and Mamba/SSM (fp32) layers. When a block previously used by Mamba (fp32 state) is reallocated to an attention layer with a smaller dtype, leftover fp32 bit patterns can appear as NaN/Inf in the new dtype. Attention kernels (FlashAttn3, FlashInfer-TRTLLM, etc.) use multiply-by-zero masking for unused positions, which does not clear NaN (0 * NaN = NaN). The stale NaN then propagates across all requests sharing the same KV-cache block, causing progressive accuracy degradation over time.
What this PR does
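To make the failure mode described under "Essential problem" concrete, here is a minimal CPU-only sketch using Python's `struct` module. The bit pattern `0x3F807E00` is an arbitrary value chosen for illustration; it is not taken from the PR, and real Mamba state will contain different values.

```python
import math
import struct

# An ordinary, finite fp32 value (bit pattern chosen for illustration).
raw = struct.pack("<I", 0x3F807E00)
(as_fp32,) = struct.unpack("<f", raw)

# Reinterpret the same four bytes as two fp16 values, which is what
# happens when a stale fp32 block is read by an fp16 attention layer.
low_half, high_half = struct.unpack("<2e", raw)

# Multiply-by-zero masking does not clear the NaN.
masked = 0.0 * low_half

print(as_fp32, low_half, high_half, masked)
```

The fp32 value is perfectly normal, but its low 16 bits (`0x7E00`) have all fp16 exponent bits set and a non-zero mantissa, so they decode as NaN, and the masked result is still NaN.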
Zeroes GPU memory of freshly allocated full-attention KV-cache blocks before they are used, but only for hybrid models (models with Mamba layers). Mamba/SSM blocks are not zeroed (they overwrite their state fully on each step). The approach:
Scheduler side: SingleTypeKVCacheManager tracks block IDs allocated since the last scheduling step (only for FullAttentionSpec layers). After scheduling, the scheduler drains these IDs into SchedulerOutput.new_block_ids_to_zero, gated behind self.has_mamba_layers.
Worker side: GPUModelRunner._update_states() receives the block IDs and calls _zero_block_ids(), which launches a single Triton kernel (_zero_kv_blocks_kernel) to zero the corresponding memory across all KV-cache segments in one GPU launch.
Optimized zeroing: a one-time _init_kv_zero_meta() precomputes absolute byte addresses of all KV-cache segments (handling both block-dim-0 and block-dim-1 layouts, multi-buffer backends, and virtual block splitting). Block IDs are transferred via pre-allocated pinned memory to overlap the H2D copy with kernel launch. This avoids 15 separate index_fill_ calls (for Qwen3.5-397B, one per layer).
CuMem compatibility: _init_kv_zero_meta() is called in gpu_worker.py outside the CuMem pool context, so the bookkeeping tensors (segment addresses, block-ID buffers) use the standard PyTorch allocator and survive sleep/wake cycles.
Scope
Gated behind self.has_mamba_layers in the scheduler. Non-hybrid (pure attention) models are completely unaffected.
FullAttentionSpec blocks: Mamba/SSM blocks are not zeroed (they overwrite their state fully each step); only attention blocks that may inherit stale Mamba fp32 data are cleared.
Performance overhead
Measured on B200, Qwen/Qwen3-0.6B, BS=500:
End-to-end benchmark on B200 (Qwen3.5-397B-A17B-FP8, TP=1 PP=1 DP=8, 2048 prompts, 500 output tokens) showed no measurable throughput degradation (output tokens/s within ±2% noise).
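The low overhead is consistent with the batched design: segment addresses are computed once, and each step zeroes all new blocks across all segments in a single pass. The following is a CPU-only sketch of that idea, not the Triton kernel from the PR. It assumes a flat `bytearray` standing in for GPU memory and a simple layout where each layer owns one contiguous segment; `build_segment_offsets` and `zero_blocks` are hypothetical names.

```python
from typing import List


def build_segment_offsets(num_layers: int, num_blocks: int,
                          page_size: int) -> List[int]:
    """Precompute, once, the byte offset where each layer's segment starts."""
    return [layer * num_blocks * page_size for layer in range(num_layers)]


def zero_blocks(memory: bytearray, segment_offsets: List[int],
                page_size: int, block_ids: List[int]) -> int:
    """Zero every (segment, block) page in one pass; return pages cleared."""
    zero_page = bytes(page_size)
    cleared = 0
    for base in segment_offsets:
        for block_id in block_ids:
            start = base + block_id * page_size
            memory[start:start + page_size] = zero_page
            cleared += 1
    return cleared


NUM_LAYERS, NUM_BLOCKS, PAGE_SIZE = 3, 8, 4
# Fill memory with 0xFF to represent stale data from a previous owner.
memory = bytearray(b"\xff" * (NUM_LAYERS * NUM_BLOCKS * PAGE_SIZE))
offsets = build_segment_offsets(NUM_LAYERS, NUM_BLOCKS, PAGE_SIZE)
cleared = zero_blocks(memory, offsets, PAGE_SIZE, block_ids=[1, 5])
```

On a GPU, the two nested loops would map onto the grid of a single kernel launch, so the launch count stays at one regardless of how many blocks or layers are involved.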
Test plan