Commit f45b90a

linear support deepgemm (sgl-project#4199)

sleepcoo authored and FlamingoPg committed
Co-authored-by: yinfan98 <1106310035@qq.com>
1 parent 5f60399 · commit f45b90a

3 files changed: 76 additions & 44 deletions


python/sglang/srt/layers/quantization/fp8_kernel.py (36 additions & 28 deletions)
@@ -29,10 +29,13 @@
 
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
+    import deep_gemm
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
 logger = logging.getLogger(__name__)
 
+_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
+
 
 @triton.jit
 def _per_token_group_quant_fp8(
@@ -722,34 +725,39 @@ def grid(META):
     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
         N, config["BLOCK_SIZE_N"]
     )
-    kernel = (
-        _w8a8_block_fp8_matmul_unrolledx4
-        if (is_hip_ == True and num_workgroups <= get_device_core_count())
-        else _w8a8_block_fp8_matmul
-    )
 
-    kernel[grid](
-        A,
-        B,
-        C,
-        As,
-        Bs,
-        M,
-        N,
-        K,
-        block_n,
-        block_k,
-        A.stride(-2),
-        A.stride(-1),
-        B.stride(1),
-        B.stride(0),
-        C.stride(-2),
-        C.stride(-1),
-        As.stride(-2),
-        As.stride(-1),
-        Bs.stride(1),
-        Bs.stride(0),
-        **config,
-    )
+    # deepgemm only supports bf16
+    if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+    else:
+        kernel = (
+            _w8a8_block_fp8_matmul_unrolledx4
+            if (is_hip_ == True and num_workgroups <= get_device_core_count())
+            else _w8a8_block_fp8_matmul
+        )
+
+        kernel[grid](
+            A,
+            B,
+            C,
+            As,
+            Bs,
+            M,
+            N,
+            K,
+            block_n,
+            block_k,
+            A.stride(-2),
+            A.stride(-1),
+            B.stride(1),
+            B.stride(0),
+            C.stride(-2),
+            C.stride(-1),
+            As.stride(-2),
+            As.stride(-1),
+            Bs.stride(1),
+            Bs.stride(0),
+            **config,
+        )
 
     return C
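
Read as plain code, the new dispatch amounts to the sketch below. dispatch_block_fp8_matmul and triton_fallback are illustrative names, not identifiers from this commit; only the gate condition and the deep_gemm.gemm_fp8_fp8_bf16_nt call mirror the hunk above.

import os

import torch

_is_cuda = torch.cuda.is_available() and torch.version.cuda
_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))


def dispatch_block_fp8_matmul(A, As, B, Bs, C, triton_fallback):
    # DeepGEMM's FP8 GEMM emits bf16 only, so the gate checks the output
    # dtype as well as the CUDA backend and the opt-in env var.
    if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
        import deep_gemm

        # Each operand travels as a (tensor, per-block scale) pair; the
        # result lands in the preallocated bf16 buffer C.
        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
    else:
        triton_fallback(A, As, B, Bs, C)
    return C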

python/sglang/test/test_block_fp8.py (39 additions & 15 deletions)
@@ -1,4 +1,5 @@
 import itertools
+import os
 import unittest
 
 import torch
@@ -11,6 +12,8 @@
     w8a8_block_fp8_matmul,
 )
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+
 
 # For test
 def native_per_token_group_quant_fp8(
@@ -208,21 +211,44 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
 
 
 class TestW8A8BlockFP8Matmul(unittest.TestCase):
-    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
-    M = [1, 7, 83, 512, 2048]
-    N = [128, 512, 1024, 4096, 7748, 13824]
-    K = [256, 4096, 5120, 3884, 13824]
-    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
-    BLOCK_SIZE = [[128, 128]]
-    SEEDS = [0]
+
+    if not _is_cuda:
+        OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
+        M = [1, 7, 83, 512, 2048]
+        NKs = [
+            (N, K)
+            for N in [128, 512, 1024, 4096, 7748, 13824]
+            for K in [256, 4096, 5120, 3884, 13824]
+        ]
+        # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
+    else:
+        # use practical shapes from DeepSeek V3 for the test
+        OUT_DTYPES = [torch.bfloat16]
+        M = [64, 128, 512, 1024, 4096]
+        NKs = [
+            (1536, 7168),
+            (3072, 1536),
+            (24576, 7168),
+            (4096, 512),
+            (7168, 2048),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2304),
+            (7168, 512),
+        ]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
 
     @classmethod
     def setUpClass(cls):
         if not torch.cuda.is_available():
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")
 
-    def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
+    def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
+        N, K = NK
         torch.manual_seed(seed)
         # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
         factor_for_scale = 1e-2
@@ -257,19 +283,17 @@ def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
     def test_w8a8_block_fp8_matmul(self):
         for params in itertools.product(
             self.M,
-            self.N,
-            self.K,
+            self.NKs,
             self.BLOCK_SIZE,
             self.OUT_DTYPES,
             self.SEEDS,
         ):
             with self.subTest(
                 M=params[0],
-                N=params[1],
-                K=params[2],
-                block_size=params[3],
-                out_dtype=params[4],
-                seed=params[5],
+                NKs=params[1],
+                block_size=params[2],
+                out_dtype=params[3],
+                seed=params[4],
             ):
                 self._w8a8_block_fp8_matmul(*params)
 
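
A side effect worth noting: by switching from separate N and K lists to paired NKs tuples, itertools.product no longer crosses every N with every K, so the CUDA run covers exactly the listed DeepSeek V3 weight shapes. A toy illustration with shortened lists:

import itertools

M = [64, 128]
NKs = [(1536, 7168), (3072, 1536)]  # paired (N, K) shapes, not independent axes

# product() now yields len(M) * len(NKs) = 4 cases; crossing separate N and K
# lists of the same lengths would have yielded 2 * 2 * 2 = 8.
for m, (n, k) in itertools.product(M, NKs):
    print(m, n, k)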

test/srt/test_fp8_kernel.py (1 addition & 1 deletion)
@@ -17,7 +17,7 @@ def setUpClass(cls):
         cls.K = 512
         cls.group_size = 128
         cls.quant_type = torch.float8_e4m3fn
-        cls.output_type = torch.float16
+        cls.output_type = torch.bfloat16
 
     @staticmethod
     def _make_A(M, K, group_size, out_dtype):
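
The output_type flip to bfloat16 lines up with the DeepGEMM path, which only emits bf16. Also note that _enable_jit_deepgemm is read once, at module import time, so the flag must be in the environment before sglang's quantization code loads. A hedged usage sketch (the import path follows the file paths above and is an assumption; exporting SGL_ENABLE_JIT_DEEPGEMM=1 in the shell before launching works just as well):

import os

# Must happen before the first sglang import, because the flag is read at
# module import time rather than per call.
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"

from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul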
