@@ -1,4 +1,5 @@
 import itertools
+import os
 import unittest
 
 import torch
@@ -11,6 +12,8 @@
     w8a8_block_fp8_matmul,
 )
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+
 
 # For test
 def native_per_token_group_quant_fp8(
@@ -208,21 +211,44 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl |
 
 
 class TestW8A8BlockFP8Matmul(unittest.TestCase):
-    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
-    M = [1, 7, 83, 512, 2048]
-    N = [128, 512, 1024, 4096, 7748, 13824]
-    K = [256, 4096, 5120, 3884, 13824]
-    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
-    BLOCK_SIZE = [[128, 128]]
-    SEEDS = [0]
+
+    if not _is_cuda:
+        OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
+        M = [1, 7, 83, 512, 2048]
+        NKs = [
+            (N, K)
+            for N in [128, 512, 1024, 4096, 7748, 13824]
+            for K in [256, 4096, 5120, 3884, 13824]
+        ]
+        # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
+    else:
+        # Use practical shapes from DeepSeek V3 for the test.
+        OUT_DTYPES = [torch.bfloat16]
+        M = [64, 128, 512, 1024, 4096]
+        NKs = [
+            (1536, 7168),
+            (3072, 1536),
+            (24576, 7168),
+            (4096, 512),
+            (7168, 2048),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2304),
+            (7168, 512),
+        ]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
 
     @classmethod
     def setUpClass(cls):
         if not torch.cuda.is_available():
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")
 
-    def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
+    def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
+        N, K = NK
         torch.manual_seed(seed)
         # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
         factor_for_scale = 1e-2
@@ -257,19 +283,17 @@ def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed): |
     def test_w8a8_block_fp8_matmul(self):
         for params in itertools.product(
             self.M,
-            self.N,
-            self.K,
+            self.NKs,
             self.BLOCK_SIZE,
             self.OUT_DTYPES,
             self.SEEDS,
         ):
             with self.subTest(
                 M=params[0],
-                N=params[1],
-                K=params[2],
-                block_size=params[3],
-                out_dtype=params[4],
-                seed=params[5],
+                NK=params[1],
+                block_size=params[2],
+                out_dtype=params[3],
+                seed=params[4],
             ):
                 self._w8a8_block_fp8_matmul(*params)
 
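A note on the new `_is_cuda` gate: ROCm builds of PyTorch can also report `torch.cuda.is_available()` as true (HIP is exposed through the `torch.cuda` namespace), but they leave `torch.version.cuda` as `None`, so checking both keeps the DeepSeek V3 branch to genuine CUDA builds. A minimal sketch of the idiom (the `print` is illustrative only, not part of the diff):

```python
import torch

# Truthy only on CUDA builds of PyTorch; ROCm builds can return True from
# is_available() while leaving torch.version.cuda as None.
_is_cuda = torch.cuda.is_available() and torch.version.cuda

print(bool(_is_cuda))
```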
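The move from independent `N`/`K` lists to paired `NKs` means `itertools.product` no longer crosses every `N` with every `K`, which is what lets the CUDA branch pin the sweep to the nine matched DeepSeek V3 GEMM shapes. A minimal standalone sketch of the pattern, assuming the usual `A: (M, K)` by `B: (N, K)` layout of these kernels (lists shortened for illustration):

```python
import itertools

M = [64, 128]
NKs = [(1536, 7168), (3072, 1536)]  # matched (N, K) pairs, as in the CUDA branch
BLOCK_SIZE = [[128, 128]]

def run_case(M, NK, block_size):
    # Same unpacking as the updated _w8a8_block_fp8_matmul helper.
    N, K = NK
    print(f"A: ({M}, {K})  x  B: ({N}, {K})  block: {block_size}")

# 2 M values x 2 NK pairs x 1 block size = 4 cases, versus 2 x 2 x 2 = 8
# if N and K were still swept independently.
for params in itertools.product(M, NKs, BLOCK_SIZE):
    run_case(*params)
```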