Support FlashMLA backend #4472

Merged
zhyncs merged 7 commits into sgl-project:main from sleepcoo:support-flashmla
Mar 16, 2025

Conversation

Collaborator

@sleepcoo sleepcoo commented Mar 16, 2025

Motivation

Integrate FlashMLA for decoding; the accuracy test currently passes. The current implementation is quite simple: FlashMLA is integrated directly as the attention backend. Later we need to abstract a fastmla_backend that uses FA3 for prefill and FlashMLA for decode.

Modifications

  • FlashMLABackend inherits from FlashInferMLAAttnBackend, using FlashInferMLAAttnBackend for prefill and FlashMLABackend for decoding.
  • Add the create_flashmla_kv_indices_triton function to produce KV indices compatible with FlashMLA's block-table format (see the sketch after this list).
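
For illustration only, here is a minimal pure-PyTorch sketch of the kind of block table create_flashmla_kv_indices_triton needs to produce. The actual PR uses a Triton kernel; the helper name, the req_to_page layout, and the page_size default below are assumptions, not the real sglang API.

import torch

def build_flashmla_block_table(seq_lens, req_to_page, page_size=64, pad_value=-1):
    """Sketch: lay out each request's KV-cache page indices row-wise,
    padded to the longest request, as a (bs, max_pages) int32 table."""
    num_pages = (seq_lens + page_size - 1) // page_size  # pages used per request
    bs, max_pages = seq_lens.numel(), int(num_pages.max())
    table = torch.full((bs, max_pages), pad_value, dtype=torch.int32)
    for i in range(bs):
        n = int(num_pages[i])
        table[i, :n] = req_to_page[i, :n].to(torch.int32)
    return table

# Two requests with 100 and 70 cached tokens (page_size=64) -> 2 pages each.
seq_lens = torch.tensor([100, 70])
req_to_page = torch.tensor([[3, 7, 0], [5, 9, 0]])
print(build_flashmla_block_table(seq_lens, req_to_page))
# tensor([[3, 7],
#         [5, 9]], dtype=torch.int32)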

Command

 python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote --tp 8 --page-size 64 --disable-cuda-graph --enable-flashmla

Todo

  • Support FlashMLA decode with cudagraph
  • Enable speculative sampling in FlashMLA
  • Add unit test
  • Performance analysis and optimization
  • Integrate FA3 prefill

@sleepcoo sleepcoo changed the title from "Support flashmla" to "Support FlashMLA backend" on Mar 16, 2025
@sleepcoo sleepcoo marked this pull request as ready for review March 16, 2025 10:09
@zhyncs zhyncs merged commit a53fe42 into sgl-project:main Mar 16, 2025
assert self.chunked_prefill_size % self.page_size == 0

if self.enable_flashmla is True:
    assert self.page_size == 64, "FlashMLA only support page_size=64"
Contributor

automatically set this
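
As a sketch of this suggestion (the helper name is hypothetical and the attribute names only mirror the excerpt above), the server could override the value with a warning instead of asserting:

import logging

logger = logging.getLogger(__name__)

def adjust_args_for_flashmla(args):
    # Hypothetical helper: force page_size to 64 when FlashMLA is enabled,
    # rather than failing the assertion above.
    if args.enable_flashmla and args.page_size != 64:
        logger.warning(
            "FlashMLA requires page_size=64; overriding page_size=%d.", args.page_size
        )
        args.page_size = 64
    assert args.chunked_prefill_size % args.page_size == 0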

flashmla_index = torch.full(
    (bs, max_seqlen_pad), -1, dtype=torch.int32, device=q.device
)
create_flashmla_kv_indices_triton[(bs,)](
Contributor

This metadata is the same for all layers, so we can compute it only once in init_forward_metadata.
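
A rough sketch of that suggestion (class and method names here are illustrative, not the actual sglang backend interface): allocate and fill the index buffer once per forward batch, then let every layer's decode read the shared buffer.

import torch

class FlashMLAMetadataSketch:
    def init_forward_metadata(self, bs, max_seqlen_pad, device):
        # Built once per batch; in the PR this buffer would be filled by the
        # create_flashmla_kv_indices_triton kernel shown (truncated) above.
        self.flashmla_index = torch.full(
            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=device
        )

    def forward_decode(self, layer_idx):
        # Every layer reuses the same table; nothing is rebuilt per layer.
        return self.flashmla_index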

Collaborator Author

I will fix this in the follow-up CUDA graph PR.

@sleepcoo sleepcoo deleted the support-flashmla branch March 17, 2025 03:34

lishicheng1996 commented Mar 17, 2025

Hi, I tested --enable-flashmla on H20*16, and it is slower than the normal version. Do you have data on the end-to-end inference speed gain with FlashMLA? Thanks very much!

@sleepcoo
Collaborator Author

Hi, I tested --enable-flashmla on H20*16, and it is slower than the normal version. Do you have data on the end-to-end inference speed gain with FlashMLA? Thanks very much!

Wait for my CUDA graph implementation.
