
Commit a53fe42

Support FlashMLA backend (#4472)

Authored by sleepcoo and FlamingoPg
Co-authored-by: yinfan98 <1106310035@qq.com>
1 parent: 1b85929

6 files changed
Lines changed: 209 additions & 1 deletion

python/sglang/srt/layers/attention/flashmla_backend.py (new file)

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
from __future__ import annotations

"""
Attention backend for FlashMLA.

The current initial integration of FlashMLA shows correct accuracy, but performance is slightly lacking.

TODO:
- Support FlashMLA decode with cuda graph
- Enable speculative sampling in FlashMLA
- Integrate FA3 prefill
"""

from typing import TYPE_CHECKING, Optional, Union

import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


# FlashMLA only supports page size 64.
PAGE_SIZE = 64


class FlashMLABackend(FlashInferMLAAttnBackend):
    """FlashMLA attention kernels (decode path)."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__(
            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
        )

        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )
        bs = forward_batch.batch_size

        # Number of pages needed to cover the longest sequence in the batch.
        max_seqlen_pad = triton.cdiv(forward_batch.seq_lens.max().item(), PAGE_SIZE)
        flashmla_index = torch.full(
            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=q.device
        )
        # Build the per-request page table expected by FlashMLA.
        create_flashmla_kv_indices_triton[(bs,)](
            self.indices_updater_decode.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            None,
            flashmla_index,
            self.indices_updater_decode.req_to_token.size(1),
            flashmla_index.size(1),
            max_seqlen_pad,
        )

        mla_metadata, mla_splits = get_mla_metadata(
            forward_batch.seq_lens.to(torch.int32),
            1 * self.num_q_heads // self.num_kv_heads,
            self.num_kv_heads,
        )

        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)

        o, _ = flash_mla_with_kvcache(
            q=reshape_q,
            k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
            block_table=flashmla_index,
            cache_seqlens=forward_batch.seq_lens.to(torch.int32),
            head_dim_v=self.kv_lora_rank,  # TODO: retrieve from config.
            tile_scheduler_metadata=mla_metadata,
            num_splits=mla_splits,
            softmax_scale=layer.scaling,
            causal=False,
        )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
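
As an aside for readers following the shape bookkeeping in forward_decode, the sketch below walks through the same reshapes with standalone tensors. All dimension values are illustrative placeholders rather than values taken from any real model config; only the reshape pattern mirrors the backend above.

import torch

PAGE_SIZE = 64               # FlashMLA page size, as in the backend above
bs = 2                       # decode batch size (one new token per request)
num_q_heads = 16             # query heads after the tensor-parallel split (placeholder)
kv_lora_rank = 512           # MLA compressed KV dim; also passed as head_dim_v (placeholder)
qk_rope_head_dim = 64        # rotary part of the head dim (placeholder)
head_dim = kv_lora_rank + qk_rope_head_dim      # per-head query dim seen by the kernel
kv_cache_dim = kv_lora_rank + qk_rope_head_dim  # per-token KV cache width
seq_lens = torch.tensor([100, 130])

# Page table: one row per request, padded to the longest sequence in the batch.
max_seqlen_pad = int((seq_lens.max().item() + PAGE_SIZE - 1) // PAGE_SIZE)
block_table = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)

# Query: [tokens, heads * head_dim] -> [bs, 1, heads, head_dim] for decode.
q = torch.randn(bs, num_q_heads * head_dim)
reshape_q = q.view(bs, -1, num_q_heads, head_dim)

# KV cache: flat token slots -> [num_pages, PAGE_SIZE, 1, kv_cache_dim].
k_cache = torch.randn(4 * PAGE_SIZE, kv_cache_dim)
paged_k_cache = k_cache.view(-1, PAGE_SIZE, 1, kv_cache_dim)

# Kernel output has head_dim_v (= kv_lora_rank here) per head; flatten for the next layer.
o = torch.randn(bs, 1, num_q_heads, kv_lora_rank)
out = o.view(-1, num_q_heads * kv_lora_rank)
print(reshape_q.shape, paged_k_cache.shape, block_table.shape, out.shape)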

python/sglang/srt/layers/attention/utils.py

Lines changed: 54 additions & 0 deletions
@@ -15,6 +15,7 @@ def create_flashinfer_kv_indices_triton(
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)
 
+    # Find the req pool idx; this maps the batch entry to its token rows.
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
 
@@ -37,3 +38,56 @@ def create_flashinfer_kv_indices_triton(
             mask=mask,
         )
         tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
+
+
+@triton.jit
+def create_flashmla_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    req_to_token_ptr_stride: tl.constexpr,
+    kv_indices_ptr_stride: tl.constexpr,
+    max_pagesize: tl.constexpr,
+):
+    PAGED_SIZE: tl.constexpr = 64
+    BLOCK_SIZE: tl.constexpr = 4096
+    NUM_PAGE_PER_BLOCK: tl.constexpr = 64
+    pid = tl.program_id(axis=0)
+
+    # Find the req pool idx; this maps the batch entry to its token rows.
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
+    num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+
+    for i in range(num_pages_loop):
+        # Token offsets at each page boundary handled by this block.
+        paged_offset = (
+            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+        ) * PAGED_SIZE
+        # Output slots for the corresponding page indices.
+        paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+
+        mask = paged_offset <= num_paged * PAGED_SIZE
+        mask_out = paged_offset_out <= num_paged
+
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + paged_offset,
+            mask=mask,
+        )
+        tl.store(
+            kv_indices_ptr + pid * kv_indices_ptr_stride + paged_offset_out,
+            data // PAGED_SIZE,
+            mask=mask_out,
+        )
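
For readers less familiar with Triton, here is a plain-PyTorch reference sketch of what create_flashmla_kv_indices_triton produces. The function name flashmla_block_table_reference is invented for illustration and is not part of the commit: for each request it reads the token slot stored at every 64-token page boundary of req_to_token and converts it to a page index by dividing by the page size (the kv_start_idx offset handling is omitted).

import torch


def flashmla_block_table_reference(req_to_token, req_pool_indices, seq_lens, page_size=64):
    """Reference-only sketch mirroring the Triton kernel above (no kv_start_idx handling)."""
    bs = req_pool_indices.numel()
    max_pages = (int(seq_lens.max().item()) + page_size - 1) // page_size
    block_table = torch.full((bs, max_pages), -1, dtype=torch.int32)
    for i in range(bs):
        pool_idx = int(req_pool_indices[i])
        num_pages = (int(seq_lens[i]) + page_size - 1) // page_size
        for p in range(num_pages):
            # Token slot index at the start of page p for this request.
            token_slot = int(req_to_token[pool_idx, p * page_size])
            # Contiguous slots within a page share one physical page.
            block_table[i, p] = token_slot // page_size
    return block_table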

python/sglang/srt/managers/schedule_batch.py

Lines changed: 5 additions & 1 deletion
@@ -71,6 +71,7 @@
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+    "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
 }
@@ -1273,7 +1274,10 @@ def merge_batch(self, other: "ScheduleBatch"):
 
     def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
-            if global_server_args_dict["enable_flashinfer_mla"]:
+            if (
+                global_server_args_dict["enable_flashinfer_mla"]
+                or global_server_args_dict["enable_flashmla"]
+            ):
                 decode_seq_lens = self.seq_lens.cpu()
             else:
                 decode_seq_lens = None

python/sglang/srt/model_executor/model_runner.py

Lines changed: 8 additions & 0 deletions
@@ -149,6 +149,7 @@ def __init__(
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -223,6 +224,9 @@ def model_specific_adjustment(self):
                        "MLA optimization is turned on. Use flashinfer mla backend."
                    )
                    server_args.attention_backend = "flashinfer_mla"
+                elif server_args.enable_flashmla:
+                    logger.info("MLA optimization is turned on. Use flashmla decode.")
+                    server_args.attention_backend = "flashmla"
                else:
                    logger.info("MLA optimization is turned on. Use triton backend.")
                    server_args.attention_backend = "triton"
@@ -840,6 +844,10 @@ def init_attention_backend(self):
                )
 
            self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "flashmla":
+            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+            self.attn_backend = FlashMLABackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"

python/sglang/srt/server_args.py

Lines changed: 8 additions & 0 deletions
@@ -173,6 +173,7 @@ class ServerArgs:
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False
     enable_flashinfer_mla: bool = False
+    enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
 
@@ -227,6 +228,8 @@ def __post_init__(self):
 
         assert self.chunked_prefill_size % self.page_size == 0
 
+        if self.enable_flashmla is True:
+            assert self.page_size == 64, "FlashMLA only supports page_size=64"
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -753,6 +756,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Enable FlashInfer MLA optimization",
         )
+        parser.add_argument(
+            "--enable-flashmla",
+            action="store_true",
+            help="Enable FlashMLA decode optimization",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",

scripts/playground/bench_speculative.py

Lines changed: 6 additions & 0 deletions
@@ -182,6 +182,12 @@ def main(args, server_args):
                 "--enable-flashinfer-mla",
             ]
         )
+    if server_args.enable_flashmla:
+        other_args.extend(
+            [
+                "--enable-flashmla",
+            ]
+        )
 
     if server_args.quantization:
         other_args.extend(
