Commit b6944f9: Support FlashMLA backend cuda graph (#4514)

Authored by sleepcoo, FlamingoPg, Hongbosherlock, and ispobock
Co-authored-by: yinfan98 <1106310035@qq.com>
Co-authored-by: Hongbosherlock <hongbosherlock@gmail.com>
Co-authored-by: ispobock <ispobaoke@163.com>

Parent: f44db16

3 files changed: 188 additions and 32 deletions

python/sglang/srt/layers/attention/flashmla_backend.py (184 additions, 30 deletions)
@@ -1,16 +1,13 @@
 from __future__ import annotations
 
 """
-Support attention backend for flashMLA.
+Support attention backend for FlashMLA.
 
-Current initial integration of FlashMLA shows normal accuracy, but performance is slightly lacking.
 #TODO
-Support FlashMLA decode with cudagraph
 Enable speculative sampling in FlashMLA
-Integrate FA3 prefill
 """
 
-
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
@@ -28,10 +25,30 @@
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInfo
 
 
 # FlashMLA only supports pagesize=64
 PAGE_SIZE = 64
+# TODO The current setup is hard-coded and will be changed after integrating with MTP.
+Q_LEN = 1
+
+
+@dataclass
+class FlashMLADecodeMetadata:
+    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    num_splits: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_splits: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.flashmla_metadata = flashmla_metadata
+        self.num_splits = num_splits
+        self.block_kv_indices = block_kv_indices
 
 
 class FlashMLABackend(FlashInferMLAAttnBackend):
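A note on the constants above: FlashMLA addresses the KV cache in fixed 64-token pages, so the block table built later needs ceil(seq_len / PAGE_SIZE) entries per request, padded to the longest sequence in the batch. The following is a minimal sketch of that arithmetic in plain PyTorch with made-up sequence lengths (the backend itself uses triton.cdiv for the same ceiling division):

import torch

PAGE_SIZE = 64  # FlashMLA only supports 64-token KV pages

# Hypothetical decode batch: three requests with different context lengths.
seq_lens = torch.tensor([100, 640, 1000], dtype=torch.int32)

# Ceiling division, equivalent to triton.cdiv(seq_lens.max().item(), PAGE_SIZE).
max_seqlen_pad = -(-int(seq_lens.max()) // PAGE_SIZE)  # 1000 tokens -> 16 pages

# One row of page indices per request, padded with -1, mirroring the backend.
block_kv_indices = torch.full((seq_lens.numel(), max_seqlen_pad), -1, dtype=torch.int32)
print(max_seqlen_pad, tuple(block_kv_indices.shape))  # 16 (3, 16)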
@@ -58,6 +75,7 @@ def __init__(
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
+        self.forward_metadata: Union[FlashMLADecodeMetadata] = None
         self.kv_lora_rank = model_runner.model_config.kv_lora_rank
         self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
@@ -67,6 +85,163 @@ def __init__(
         self.q_data_type = model_runner.dtype
         self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
 
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+
+        bs = forward_batch.batch_size
+        spec_info = forward_batch.spec_info
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(
+                    forward_batch.seq_lens.max().item(), PAGE_SIZE
+                )
+                block_kv_indices = torch.full(
+                    (bs, max_seqlen_pad),
+                    -1,
+                    dtype=torch.int32,
+                    device=forward_batch.seq_lens.device,
+                )
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
+                mla_metadata, num_splits = get_mla_metadata(
+                    forward_batch.seq_lens.to(torch.int32),
+                    Q_LEN * self.num_q_heads // self.num_kv_heads,
+                    self.num_kv_heads,
+                )
+                self.forward_metadata = FlashMLADecodeMetadata(
+                    mla_metadata,
+                    num_splits,
+                    block_kv_indices,
+                )
+            else:
+                super().init_forward_metadata(forward_batch)
+        else:
+            super().init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        if block_kv_indices is None:
+            cuda_graph_kv_indices = torch.full(
+                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
+                1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = block_kv_indices
+
+        self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
+            torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
+            Q_LEN * self.num_q_heads // self.num_kv_heads,
+            self.num_kv_heads,
+        )
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    None,
+                    self.cuda_graph_kv_indices,
+                    self.req_to_token.stride(0),
+                    self.cuda_graph_kv_indices.stride(0),
+                )
+                mla_metadata, num_splits = get_mla_metadata(
+                    seq_lens.to(torch.int32),
+                    Q_LEN * self.num_q_heads // self.num_kv_heads,
+                    self.num_kv_heads,
+                )
+                self.cuda_graph_mla_metadata.copy_(mla_metadata)
+                self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+                self.forward_metadata = FlashMLADecodeMetadata(
+                    self.cuda_graph_mla_metadata,
+                    self.cuda_graph_num_splits[: bs + 1],
+                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+                )
+
+        else:
+            super().init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+
+        if forward_mode.is_decode_or_idle():
+            seq_lens = seq_lens[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32),
+                Q_LEN * self.num_q_heads // self.num_kv_heads,
+                self.num_kv_heads,
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata
+            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+
+        else:
+            super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
     def forward_decode(
         self,
         q: torch.Tensor,
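The three methods added above follow the usual CUDA graph recipe: init_cuda_graph_state allocates worst-case, fixed-shape buffers once; the capture hook records the decode launches against those buffers; and the replay hook only copy_()s fresh values into the same storage, so tensor pointers and shapes never change between replays. For readers unfamiliar with that pattern, here is a minimal, generic PyTorch sketch (not sglang code; it only illustrates the static-buffer-plus-copy_ idea and needs a CUDA device):

import torch

# Static buffers: the graph records their addresses, so they must never be reallocated.
static_in = torch.zeros(8, 1024, device="cuda")
static_out = torch.zeros(8, 1024, device="cuda")
weight = torch.randn(1024, 1024, device="cuda")

def step(x):
    # Stand-in for the real decode kernels.
    return torch.relu(x @ weight)

# Warm up on a side stream (recommended before capture), then capture one step.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(step(static_in))
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(step(static_in))

# Replay: write new inputs into the captured buffers, then relaunch the whole graph.
static_in.copy_(torch.randn(8, 1024, device="cuda"))
g.replay()
print(static_out.norm().item())

This is why the replay hook copies new mla_metadata and num_splits into cuda_graph_mla_metadata and cuda_graph_num_splits instead of assigning fresh tensors.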
@@ -88,39 +263,18 @@ def forward_decode(
             v,
         )
         bs = forward_batch.batch_size
-
-        max_seqlen_pad = triton.cdiv(forward_batch.seq_lens.max().item(), PAGE_SIZE)
-        flashmla_index = torch.full(
-            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=q.device
-        )
-        create_flashmla_kv_indices_triton[(bs,)](
-            self.indices_updater_decode.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            None,
-            flashmla_index,
-            self.indices_updater_decode.req_to_token.size(1),
-            flashmla_index.size(1),
-            max_seqlen_pad,
-        )
-
-        mla_metadata, mla_splits = get_mla_metadata(
-            forward_batch.seq_lens.to(torch.int32),
-            1 * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
-        )
-
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
 
         o, _ = flash_mla_with_kvcache(
             q=reshape_q,
             k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
-            block_table=flashmla_index,
+            block_table=self.forward_metadata.block_kv_indices,
             cache_seqlens=forward_batch.seq_lens.to(torch.int32),
             head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
-            tile_scheduler_metadata=mla_metadata,
-            num_splits=mla_splits,
+            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+            num_splits=self.forward_metadata.num_splits,
             softmax_scale=layer.scaling,
             causal=False,
         )
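With this change, forward_decode no longer rebuilds the block table or recomputes the tile-scheduler metadata on every step; it only reads self.forward_metadata, which keeps the decode launch sequence stable enough to capture. The tensor shapes involved can be checked with a small standalone sketch; the dimensions are illustrative only (kv_lora_rank=512 and qk_rope_head_dim=64 are typical MLA values, not read from any config):

import torch

PAGE_SIZE = 64
kv_lora_rank, qk_rope_head_dim = 512, 64        # illustrative MLA dimensions
kv_cache_dim = kv_lora_rank + qk_rope_head_dim  # 576, mirroring self.kv_cache_dim
bs, num_heads, num_pages = 3, 16, 40

# Decode-step query: one token per request, reshaped like reshape_q above.
q = torch.randn(bs, num_heads * kv_cache_dim)
reshape_q = q.view(bs, -1, num_heads, kv_cache_dim)       # (3, 1, 16, 576)

# Paged KV cache viewed as 64-token pages of the compressed latent plus rope part.
k_cache = torch.randn(num_pages * PAGE_SIZE, kv_cache_dim)
paged_kv = k_cache.view(-1, PAGE_SIZE, 1, kv_cache_dim)   # (40, 64, 1, 576)

print(reshape_q.shape, paged_kv.shape)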

python/sglang/srt/layers/attention/utils.py (0 additions, 1 deletion)
@@ -49,7 +49,6 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
-    max_pagesize: tl.constexpr,
 ):
     PAGED_SIZE: tl.constexpr = 64
     BLOCK_SIZE: tl.constexpr = 4096

python/sglang/srt/server_args.py (4 additions, 1 deletion)
@@ -232,7 +232,10 @@ def __post_init__(self):
         assert self.chunked_prefill_size % self.page_size == 0
 
         if self.enable_flashmla is True:
-            assert self.page_size == 64, "FlashMLA only support page_size=64"
+            logger.warning(
+                "FlashMLA only supports a page_size of 64, change page_size to 64."
+            )
+            self.page_size = 64
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
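Since FlashMLA can only address the KV pool in 64-token pages, the server now coerces page_size to 64 with a warning instead of asserting. The coupling is easy to see in isolation: every entry of the block table names one 64-token page, and a token's physical slot is recovered from its page id plus an offset within that page. A tiny, self-contained illustration with hypothetical token indices (not sglang code):

PAGE_SIZE = 64  # must match FlashMLA's fixed page size

def locate(token_index: int) -> tuple[int, int]:
    # Page that holds the token, and the token's offset inside that page.
    return token_index // PAGE_SIZE, token_index % PAGE_SIZE

print(locate(0))     # (0, 0)   first token of the first page
print(locate(63))    # (0, 63)  last token of the first page
print(locate(1000))  # (15, 40) token 1000 sits in the 16th page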
