|
| 1 | +""" |
| 2 | +FP8 MatMul + Reduce-Scatter Fusion Example |
| 3 | +========================================== |
| 4 | +This example extends the matmul_reduce_scatter example to use FP8 inputs. |
| 5 | +Each rank holds FP8 A and B shards; the kernel computes a local FP8 GEMM |
| 6 | +(accumulating in FP32 via ``hl.dot``), writes the float16 partial result to |
| 7 | +a symmetric-memory buffer, performs an intra-group barrier, and then |
| 8 | +reduce-scatters: each rank accumulates the rows it owns from all peers' |
| 9 | +buffers, producing a ``[M//WORLD_SIZE, N]`` float16 output. |
| 10 | +""" |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +import functools |
| 15 | +import os |
| 16 | + |
| 17 | +import torch |
| 18 | +from torch._C._distributed_c10d import _SymmetricMemory |
| 19 | +import torch.distributed as dist |
| 20 | +import torch.distributed._symmetric_memory as symm_mem |
| 21 | + |
| 22 | +import helion |
| 23 | +from helion._testing import DEVICE |
| 24 | +from helion._testing import run_example |
| 25 | +import helion.language as hl |
| 26 | +from helion.runtime.dist_utils import symm_mem_sync |
| 27 | + |
| 28 | + |
@helion.kernel(
    config=helion.Config(
        block_sizes=[64, 64, 32],  # tile sizes along M, N, K
        num_warps=8,
        num_stages=3,
    ),
    # Shapes are baked into the compiled kernel (symmetric-memory exchange
    # requires every rank to run the identical program).
    static_shapes=True,
    ignore_warnings=[helion.exc.TensorOperationInWrapper],
)
def fp8_matmul_reduce_scatter_kernel(
    a: torch.Tensor,  # [M, K] float8_e4m3fn, this rank's A shard
    b: torch.Tensor,  # [K, N] float8_e4m3fn, this rank's B shard
    symm_mem_buffer: torch.Tensor,  # [M, N] float16, symmetric memory
    signal_pad_ptrs: torch.Tensor,  # device pointers to per-rank signal pads
    RANK: hl.constexpr,
    WORLD_SIZE: hl.constexpr,
    GROUP_NAME: hl.ProcessGroupName,
) -> torch.Tensor:
    """
    Fused FP8 MatMul + Reduce-Scatter kernel.

    Computes ``(A @ B).to(float16)`` in a distributed reduce-scatter pattern:
    each rank emits only its ``M // WORLD_SIZE`` output rows.

    Per output tile the kernel (1) runs a local FP8 GEMM accumulating in
    FP32, (2) publishes the float16 partial to this rank's symmetric-memory
    buffer, (3) barriers so peers' partials become visible, and (4) for the
    rows this rank owns, sums the matching tile from every peer's buffer.

    Args:
        a: ``[M, K]`` FP8 (e4m3) left operand.
        b: ``[K, N]`` FP8 (e4m3) right operand.
        symm_mem_buffer: ``[M, N]`` float16 symmetric-memory staging buffer.
        signal_pad_ptrs: signal-pad pointer table used by ``symm_mem_sync``.
        RANK: this rank's index within the group (compile-time constant).
        WORLD_SIZE: number of ranks in the group (compile-time constant).
        GROUP_NAME: process-group name used to resolve remote buffers.

    Returns:
        ``[M // WORLD_SIZE, N]`` float16 tensor holding this rank's rows of
        the reduced product.
    """
    M, K = a.size()
    K2, N = b.size()
    # NOTE(review): assumes M is divisible by WORLD_SIZE — not checked here.
    M_scatter = M // WORLD_SIZE  # type: ignore[unsupported-operation]

    output = torch.empty([M_scatter, N], dtype=torch.float16, device=a.device)

    # Views of every peer's copy of symm_mem_buffer (including our own).
    buffer_tuple = torch.ops.symm_mem.get_remote_tensors(symm_mem_buffer, GROUP_NAME)

    # Row range of the output that this rank is responsible for reducing.
    scatter_start = RANK * M_scatter  # type: ignore[unsupported-operation]
    scatter_end = scatter_start + M_scatter  # type: ignore[unsupported-operation]

    for tile_m, tile_n in hl.tile([M, N]):
        # FP8 GEMM tile, accumulating in FP32
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(K):
            acc = hl.dot(a[tile_m, tile_k], b[tile_k, tile_n], acc=acc)

        # Store float16 partial result to this rank's symmetric-memory buffer
        symm_mem_buffer[tile_m, tile_n] = acc.to(torch.float16)

        # Barrier: release our write, acquire peers' writes
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, None, RANK, WORLD_SIZE, True, True),
            output_like=None,
        )

        # Reduce-scatter: accumulate only the rows this rank owns
        # (assumes the M block size evenly divides M_scatter, so a tile never
        # straddles the ownership boundary — TODO confirm).
        if tile_m.begin >= scatter_start and tile_m.begin < scatter_end:  # type: ignore[unsupported-operation]
            acc_reduce = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for remote_buffer in buffer_tuple:
                # Upcast each float16 partial so the sum accumulates in FP32.
                acc_reduce = acc_reduce + remote_buffer[tile_m, tile_n].to(
                    torch.float32
                )
            output[tile_m.index - scatter_start, tile_n] = acc_reduce.to(torch.float16)  # type: ignore[unsupported-operation]

        # Final barrier (release only)
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, None, RANK, WORLD_SIZE, True, False),
            output_like=None,
        )

    return output
| 97 | + |
| 98 | + |
def helion_fp8_matmul_reduce_scatter(
    symm_mem_buffer: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
) -> torch.Tensor:
    """
    Rendezvous on the pre-allocated symmetric buffer and launch the FP8
    reduce-scatter kernel on the WORLD process group.

    Args:
        symm_mem_buffer: Pre-allocated symmetric-memory buffer ``[M, N]`` float16.
        a: Local FP8 A shard ``[M, K]`` (``torch.float8_e4m3fn``).
        b: Local FP8 B shard ``[K, N]`` (``torch.float8_e4m3fn``).

    Returns:
        This rank's ``[M // world_size, N]`` float16 slice of the reduced product.

    Raises:
        RuntimeError: if the default process group has not been initialized.
    """
    world_group = dist.group.WORLD
    if world_group is None:
        raise RuntimeError("Distributed group is not initialized")

    # Rendezvous yields the handle carrying rank/world-size and the
    # device-side signal-pad pointer table the kernel barriers on.
    handle = symm_mem.rendezvous(symm_mem_buffer, world_group.group_name)

    return fp8_matmul_reduce_scatter_kernel(
        a,
        b,
        symm_mem_buffer,
        handle.signal_pad_ptrs_dev,
        RANK=handle.rank,
        WORLD_SIZE=handle.world_size,
        GROUP_NAME=world_group.group_name,
    )
| 128 | + |
| 129 | + |
def reference_fp8_matmul_reduce_scatter(
    a: torch.Tensor,
    b: torch.Tensor,
) -> torch.Tensor:
    """
    Eager reference implementation: upcast the FP8 shards to float32,
    multiply, cast the product to float16, then reduce-scatter the rows
    across the WORLD group.

    Raises:
        RuntimeError: if the default process group has not been initialized.
    """
    world_group = dist.group.WORLD
    if world_group is None:
        raise RuntimeError("Distributed group is not initialized")

    a_f32 = a.to(torch.float32)
    b_f32 = b.to(torch.float32)
    full_product = torch.mm(a_f32, b_f32).to(torch.float16)

    # Each rank keeps an equal contiguous slice of rows after the reduction.
    rows_per_rank = full_product.shape[0] // dist.get_world_size(world_group)
    output = torch.empty(
        rows_per_rank,
        full_product.shape[1],
        dtype=full_product.dtype,
        device=full_product.device,
    )
    dist.reduce_scatter_tensor(output, full_product, group=world_group)
    return output
| 148 | + |
| 149 | + |
def test(M: int, N: int, K: int, device: torch.device) -> None:
    """Compare the fused FP8 kernel against the eager reference on this rank."""
    rank = dist.get_rank()

    # A shard differs per rank (rank-offset seed) ...
    torch.manual_seed(42 + rank)
    a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)

    # ... while B is identical on every rank (shared seed).
    torch.manual_seed(42)
    b = torch.randn(K, N, device=device).to(torch.float8_e4m3fn)

    # Allocate and rendezvous the symmetric staging buffer once, up front.
    symm_mem_buffer = symm_mem.empty(M, N, dtype=torch.float16, device=device)
    symm_mem.rendezvous(symm_mem_buffer, dist.group.WORLD.group_name)  # type: ignore[union-attr]

    # Loose tolerances: FP8 quantization dominates the error budget.
    run_example(
        functools.partial(helion_fp8_matmul_reduce_scatter, symm_mem_buffer),
        reference_fp8_matmul_reduce_scatter,
        (a, b),
        rtol=2e-1,
        atol=2e-1,
    )
| 172 | + |
| 173 | + |
def main() -> None:
    """Per-process entry point for a single-node multi-GPU torchrun launch."""
    # Enlarge the signal pad before any symmetric-memory rendezvous happens.
    _SymmetricMemory.signal_pad_size = 1024 * 1024 * 16

    # Bind this process to the GPU matching its local rank.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.manual_seed(42 + local_rank)
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    dist.init_process_group("nccl")
    test(M=512, N=768, K=1024, device=device)
    dist.destroy_process_group()
| 185 | + |
| 186 | + |
if __name__ == "__main__":
    """
    Run with:
    python -m torch.distributed.run --standalone \\
        --nproc-per-node 4 \\
        --rdzv-backend c10d --rdzv-endpoint localhost:0 \\
        examples/distributed/fp8_matmul_reduce_scatter.py
    """
    # Symmetric memory + NCCL require CUDA; fail fast on other backends.
    assert DEVICE.type == "cuda", "Requires CUDA device"
    main()