Support FlashMLA backend CUDA graph #4514
Conversation
Co-authored-by: yinfan98 <1106310035@qq.com> Co-authored-by: Hongbosherlock <hongbosherlock@gmail.com>
    if forward_mode.is_decode_or_idle():
        seq_lens = seq_lens[:bs]
        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
We should avoid CPU-GPU synchronization by avoiding the use of seq_lens.max().item().
Can you derive this value from seq_lens_cpu?
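A minimal sketch of the suggested fix, using plain Python ints to stand in for the host-side `seq_lens_cpu` tensor (names and the page size are illustrative, not the PR's actual code). Calling `.max().item()` on a CUDA tensor forces the host to wait for the device; reading the already-available CPU copy does not.

```python
PAGE_SIZE = 64  # assumed page size for this sketch

def cdiv(a: int, b: int) -> int:
    # Same semantics as triton.cdiv, without importing triton.
    return (a + b - 1) // b

def max_seqlen_pad_from_cpu(seq_lens_cpu: list, bs: int) -> int:
    # max() over host-side data is free; seq_lens.max().item() on a CUDA
    # tensor would stall until the device finishes and copies the value back.
    return cdiv(max(seq_lens_cpu[:bs]), PAGE_SIZE)

print(max_seqlen_pad_from_cpu([100, 257, 3], 3))  # -> 5 (ceil(257 / 64))
```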
I found this problem. I was too busy at work today and didn't have time to modify it. I will fix it tomorrow.
I attempted to test FlashMLA + CUDA Graph on your commit, but I was not successful. The following error occurred:
The test command I used is:
However, everything works fine as long as I don't add
Environment:
hi @sleepcoo Great PR! But I did some simple tests, and it seems that the performance of FlashMLA is not as good as that of triton_backend. What could be the reason?
command:
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3 --max-concurrency 1 --random-input 128 --random-output 1024 --dataset-path /models/dataset/ShareGPT_V3_unfiltered_cleaned_split.json
flashmla:
python3 -m sglang.launch_server --model-path /models/DeepSeek-R1
triton_backend:
python3 -m sglang.launch_server --model-path /models/DeepSeek-R1
result:
flashmla:
============ Serving Benchmark Result ============
triton_backend:
============ Serving Benchmark Result ============
Environment:
In this PR, we have fixed the performance issues and tested them. In certain cases, FlashMLA has an advantage.
Motivation
Support CUDA graph for the FlashMLA backend. Optimize the index calculation by completing it during init_forward.
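The idea of completing the index calculation in init_forward can be sketched as follows. This is an illustrative, simplified version (pure Python, hypothetical names and page size, not the PR's implementation): the per-request page-index table is built once while preparing forward metadata, so the decode hot path, and a captured CUDA graph replaying it, does no index math.

```python
PAGE_SIZE = 64  # assumed page size for this sketch

def cdiv(a: int, b: int) -> int:
    # Ceiling division.
    return (a + b - 1) // b

def init_forward_metadata(seq_lens):
    # Precompute each request's page ids, padded with -1 to a rectangular
    # table, during metadata init rather than at kernel-launch time.
    max_pages = cdiv(max(seq_lens), PAGE_SIZE)
    block_table = []
    next_free = 0  # toy allocator: pages handed out sequentially
    for n in seq_lens:
        npages = cdiv(n, PAGE_SIZE)
        row = list(range(next_free, next_free + npages))
        row += [-1] * (max_pages - npages)
        block_table.append(row)
        next_free += npages
    return block_table

bt = init_forward_metadata([100, 257, 3])
print(len(bt), len(bt[0]))  # 3 5
print(bt[0])                # [0, 1, -1, -1, -1]
```

A fixed-shape table like this is also what makes the step CUDA-graph friendly: the captured kernels see the same tensor shapes on every replay.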
Modifications
Test
DeepSeek-V3 accuracy test
GSM8K Accuracy: 0.980
MMLU Average accuracy: 0.878
TODO