from __future__ import annotations

"""
-Support attention backend for flashMLA.
+Support attention backend for FlashMLA.

-Current initial integration of FlashMLA shows normal accuracy, but performance is slightly lacking.
#TODO
-Support FlashMLA decode with cudagraph
Enable speculative sampling in FlashMLA
-Integrate FA3 prefill
"""

-
+from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInfo


# FlashMLA only supports pagesize=64
PAGE_SIZE = 64
+# TODO: The current setup is hard-coded and will be changed after integrating with MTP.
+Q_LEN = 1
+
+
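+# Decode-path metadata reused across steps: FlashMLA tile-scheduler metadata,
+# per-request split counts, and the paged block table (PAGE_SIZE-token pages).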
+@dataclass
+class FlashMLADecodeMetadata:
+    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    num_splits: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_splits: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.flashmla_metadata = flashmla_metadata
+        self.num_splits = num_splits
+        self.block_kv_indices = block_kv_indices


class FlashMLABackend(FlashInferMLAAttnBackend):
@@ -58,6 +75,7 @@ def __init__(
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
+        self.forward_metadata: Optional[FlashMLADecodeMetadata] = None
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
@@ -67,6 +85,163 @@ def __init__(
        self.q_data_type = model_runner.dtype
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+
+        bs = forward_batch.batch_size
+        spec_info = forward_batch.spec_info
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(
+                    forward_batch.seq_lens.max().item(), PAGE_SIZE
+                )
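+                # Build a paged block table padded to the longest sequence in the
+                # batch (in PAGE_SIZE pages); unused slots are left as -1.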
+                block_kv_indices = torch.full(
+                    (bs, max_seqlen_pad),
+                    -1,
+                    dtype=torch.int32,
+                    device=forward_batch.seq_lens.device,
+                )
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
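+                # Precompute FlashMLA tile-scheduler metadata and split counts for this batch.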
+                mla_metadata, num_splits = get_mla_metadata(
+                    forward_batch.seq_lens.to(torch.int32),
+                    Q_LEN * self.num_q_heads // self.num_kv_heads,
+                    self.num_kv_heads,
+                )
+                self.forward_metadata = FlashMLADecodeMetadata(
+                    mla_metadata,
+                    num_splits,
+                    block_kv_indices,
+                )
+            else:
+                super().init_forward_metadata(forward_batch)
+        else:
+            super().init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
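+        # Allocate (or reuse) a persistent block-table buffer sized for
+        # max_context_len so CUDA graph capture/replay can update it in place.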
+        if block_kv_indices is None:
+            cuda_graph_kv_indices = torch.full(
+                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
+                1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = block_kv_indices
+
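+        # Seed the scheduler metadata with dummy length-1 sequences; capture and
+        # replay overwrite these buffers in place via copy_().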
+        self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
+            torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
+            Q_LEN * self.num_q_heads // self.num_kv_heads,
+            self.num_kv_heads,
+        )
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    None,
+                    self.cuda_graph_kv_indices,
+                    self.req_to_token.stride(0),
+                    self.cuda_graph_kv_indices.stride(0),
+                )
+                mla_metadata, num_splits = get_mla_metadata(
+                    seq_lens.to(torch.int32),
+                    Q_LEN * self.num_q_heads // self.num_kv_heads,
+                    self.num_kv_heads,
+                )
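+                # Copy into the persistent buffers so the captured graph reads up-to-date metadata.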
+                self.cuda_graph_mla_metadata.copy_(mla_metadata)
+                self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+                self.forward_metadata = FlashMLADecodeMetadata(
+                    self.cuda_graph_mla_metadata,
+                    self.cuda_graph_num_splits[: bs + 1],
+                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+                )
+
+            else:
+                super().init_forward_metadata_capture_cuda_graph(
+                    bs,
+                    num_tokens,
+                    req_pool_indices,
+                    seq_lens,
+                    encoder_lens,
+                    forward_mode,
+                    spec_info,
+                )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+
+        if forward_mode.is_decode_or_idle():
+            seq_lens = seq_lens[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
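+            # Update the persistent CUDA graph buffers in place for the new batch contents.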
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32),
+                Q_LEN * self.num_q_heads // self.num_kv_heads,
+                self.num_kv_heads,
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
+            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+
+        else:
+            super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
    def forward_decode(
        self,
        q: torch.Tensor,
@@ -88,39 +263,18 @@ def forward_decode(
                    v,
                )
        bs = forward_batch.batch_size
-
-        max_seqlen_pad = triton.cdiv(forward_batch.seq_lens.max().item(), PAGE_SIZE)
-        flashmla_index = torch.full(
-            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=q.device
-        )
-        create_flashmla_kv_indices_triton[(bs,)](
-            self.indices_updater_decode.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            None,
-            flashmla_index,
-            self.indices_updater_decode.req_to_token.size(1),
-            flashmla_index.size(1),
-            max_seqlen_pad,
-        )
-
-        mla_metadata, mla_splits = get_mla_metadata(
-            forward_batch.seq_lens.to(torch.int32),
-            1 * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
-        )
-
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)

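+        # Decode attention over the paged KV cache, driven by the metadata prepared
+        # in init_forward_metadata (or by the CUDA graph capture/replay path).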
        o, _ = flash_mla_with_kvcache(
            q=reshape_q,
            k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
-            block_table=flashmla_index,
+            block_table=self.forward_metadata.block_kv_indices,
            cache_seqlens=forward_batch.seq_lens.to(torch.int32),
            head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
-            tile_scheduler_metadata=mla_metadata,
-            num_splits=mla_splits,
+            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+            num_splits=self.forward_metadata.num_splits,
            softmax_scale=layer.scaling,
            causal=False,
        )