|
21 | 21 | import torch.distributed._symmetric_memory as symm_mem |
22 | 22 |
|
23 | 23 | import helion |
24 | | -from helion.autotuner.base_search import _assert_close as assert_close_with_mismatch_tolerance |
25 | 24 | from helion._testing import DEVICE |
26 | 25 | from helion._testing import run_example |
| 26 | +from helion.autotuner.base_search import ( |
| 27 | + _assert_close as assert_close_with_mismatch_tolerance, |
| 28 | +) |
27 | 29 | import helion.language as hl |
28 | 30 | from helion.runtime.dist_utils import symm_mem_sync |
29 | 31 |
|
|
33 | 35 | max_mismatch_pct=1e-3, |
34 | 36 | ) |
35 | 37 |
|
| 38 | +config = helion.Config( |
| 39 | + block_sizes=[64, 64, 32], # M, N, K |
| 40 | + num_warps=8, |
| 41 | + num_stages=3, |
| 42 | +) |
| 43 | + |
| 44 | +# config = helion.Config(block_sizes=[64, 128, 128], indexing=['pointer', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[1], load_eviction_policies=['last', '', '', '', '', 'first', '', '', '', '', '', 'last'], loop_orders=[[0, 1]], num_sm_multiplier=2, num_stages=1, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, None], range_multi_buffers=[True, None], range_unroll_factors=[3, 0], range_warp_specializes=[]) |
| 45 | + |
36 | 46 |
|
37 | 47 | @helion.kernel( |
38 | | - config=helion.Config( |
39 | | - block_sizes=[64, 64, 32], # M, N, K |
40 | | - num_warps=8, |
41 | | - num_stages=3, |
42 | | - ), |
| 48 | + config=config, |
43 | 49 | static_shapes=True, |
44 | 50 | ignore_warnings=[helion.exc.TensorOperationInWrapper], |
45 | 51 | autotune_baseline_accuracy_check_fn=functools.partial( |
@@ -82,8 +88,10 @@ def fp8_matmul_reduce_scatter_kernel( |
82 | 88 | acc = hl.dot(a[tile_m, tile_k], b[tile_k, tile_n], acc=acc) |
83 | 89 |
|
84 | 90 | # Apply per-row and per-column scales |
85 | | - acc = acc * scale_a[tile_m, :].to(torch.float32) * scale_b[:, tile_n].to( |
86 | | - torch.float32 |
| 91 | + acc = ( |
| 92 | + acc |
| 93 | + * scale_a[tile_m, :].to(torch.float32) |
| 94 | + * scale_b[:, tile_n].to(torch.float32) |
87 | 95 | ) |
88 | 96 |
|
89 | 97 | # Store bfloat16 partial result to this rank's symmetric-memory buffer |
@@ -165,7 +173,9 @@ def reference_fp8_matmul_reduce_scatter( |
165 | 173 | if group is None: |
166 | 174 | raise RuntimeError("Distributed group is not initialized") |
167 | 175 |
|
168 | | - c = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16) |
| 176 | + c = torch._scaled_mm( |
| 177 | + a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16 |
| 178 | + ) |
169 | 179 |
|
170 | 180 | world_size = dist.get_world_size(group) |
171 | 181 | M_scatter = c.shape[0] // world_size |
@@ -223,7 +233,10 @@ def test(M: int, N: int, K: int, device: torch.device) -> None: |
223 | 233 |
|
224 | 234 | run_example( |
225 | 235 | functools.partial(helion_fp8_matmul_reduce_scatter, symm_mem_buffer), |
226 | | - {"nccl+cublas": reference_fp8_matmul_reduce_scatter, "fused_baseline": reference_fused_scaled_matmul_reduce_scatter}, |
| 236 | + { |
| 237 | + "nccl+cublas": reference_fp8_matmul_reduce_scatter, |
| 238 | + "fused_baseline": reference_fused_scaled_matmul_reduce_scatter, |
| 239 | + }, |
227 | 240 | (a, b, scale_a, scale_b), |
228 | 241 | **tolerance, |
229 | 242 | ) |
|
0 commit comments