"""
Support attention backend for FlashMLA.

The initial FlashMLA integration matches the accuracy of the existing
backends, but its performance is still slightly behind them.

TODO:
- Support FlashMLA decode with CUDA graph.
- Enable speculative sampling with FlashMLA.
- Integrate the FA3 prefill kernel.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata

from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner


# FlashMLA only supports a KV-cache page size of 64 tokens.
PAGE_SIZE = 64
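# A request with L cached tokens therefore occupies triton.cdiv(L, PAGE_SIZE)
# pages; e.g. 130 tokens -> 3 pages, with the unused tail slots of the block
# table left at -1 (see forward_decode below).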


class FlashMLABackend(FlashInferMLAAttnBackend):
    """Attention backend that runs MLA decode with the FlashMLA kernels."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__(
            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
        )

        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        # Identical to num_q_heads; kept as a separate alias for readability.
        self.num_local_heads = self.num_q_heads
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
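        # Note: for DeepSeek-V2/V3 style MLA weights this is 512 + 64 = 576;
        # each cached token stores the compressed KV latent concatenated with
        # its RoPE key slice.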

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                # Write this step's new KV entries into the paged KV cache.
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )
        bs = forward_batch.batch_size

        max_seqlen_pad = triton.cdiv(forward_batch.seq_lens.max().item(), PAGE_SIZE)
        flashmla_index = torch.full(
            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=q.device
        )
        create_flashmla_kv_indices_triton[(bs,)](
            self.indices_updater_decode.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            None,
            flashmla_index,
            self.indices_updater_decode.req_to_token.size(1),
            flashmla_index.size(1),
            max_seqlen_pad,
        )
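        # flashmla_index now maps each request to the physical page ids of its
        # KV cache in order; slots past cdiv(seq_len, PAGE_SIZE) stay at -1 and
        # are never read, since the kernel only trusts cache_seqlens.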

        mla_metadata, mla_splits = get_mla_metadata(
            forward_batch.seq_lens.to(torch.int32),
            # num_q_tokens_per_kv_head: q_seq_len (1 in decode) * num_q_heads // num_kv_heads
            1 * self.num_q_heads // self.num_kv_heads,
            self.num_kv_heads,
        )
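        # tile_scheduler_metadata / num_splits describe how FlashMLA partitions
        # the KV sequences across SMs; they depend on the current sequence
        # lengths, so they are recomputed on every decode step.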

        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        # Reshape to (bs, q_len, num_heads, head_dim) as expected by FlashMLA.
        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)

        o, _ = flash_mla_with_kvcache(
            q=reshape_q,
            k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
            block_table=flashmla_index,
            cache_seqlens=forward_batch.seq_lens.to(torch.int32),
            head_dim_v=self.kv_lora_rank,  # TODO: retrieve from config.
            tile_scheduler_metadata=mla_metadata,
            num_splits=mla_splits,
            softmax_scale=layer.scaling,
            causal=False,
        )

        # Flatten back to (num_tokens, num_heads * v_head_dim) for the caller.
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
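

# ---------------------------------------------------------------------------
# Minimal standalone sketch of the decode call used above, independent of the
# backend class. It assumes flash_mla is installed on a supported (Hopper)
# GPU; the head counts and dims follow DeepSeek-style MLA (head_dim 576 =
# 512 latent + 64 RoPE) and are illustrative, not read from a model config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bs, num_q_heads, head_dim, head_dim_v = 2, 16, 576, 512
    seq_lens = torch.tensor([100, 200], dtype=torch.int32, device="cuda")
    max_pages = triton.cdiv(int(seq_lens.max()), PAGE_SIZE)

    q = torch.randn(
        bs, 1, num_q_heads, head_dim, dtype=torch.bfloat16, device="cuda"
    )
    k_cache = torch.randn(
        bs * max_pages, PAGE_SIZE, 1, head_dim, dtype=torch.bfloat16, device="cuda"
    )
    # Trivial block table: request i owns pages [i * max_pages, (i+1) * max_pages).
    block_table = torch.arange(
        bs * max_pages, dtype=torch.int32, device="cuda"
    ).view(bs, max_pages)

    metadata, splits = get_mla_metadata(seq_lens, 1 * num_q_heads // 1, 1)
    out, _ = flash_mla_with_kvcache(
        q,
        k_cache,
        block_table,
        seq_lens,
        head_dim_v,
        metadata,
        splits,
        softmax_scale=head_dim**-0.5,
        causal=False,
    )
    print(out.shape)  # -> torch.Size([2, 1, 16, 512])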