|
| 1 | +""" |
| 2 | +2D-Parallel Matrix Multiplication Example |
| 3 | +========================================== |
| 4 | +Demonstrates two independent dimensions of parallelism in a single fused kernel: |
| 5 | +
|
| 6 | +* **Sequence parallelism (SP)**: each rank owns a disjoint shard of input rows. |
| 7 | + SP ranks never communicate with each other – they produce fully independent |
| 8 | + output rows. |
| 9 | +
|
| 10 | +* **Tensor parallelism (TP)**: within each SP group, every rank holds a shard of |
| 11 | + the K (inner/reduction) dimension. Partial products must be summed via an |
| 12 | + intra-TP all-reduce implemented with symmetric-memory signal pads. |
| 13 | +
|
| 14 | +Rank layout (example: SP=2, TP=2, total 4 GPUs):: |
| 15 | +
|
| 16 | + TP=0 TP=1 |
| 17 | +SP=0: rank 0 rank 1 <- compute rows 0 .. M/2 of output |
| 18 | +SP=1: rank 2 rank 3 <- compute rows M/2 .. M of output |
| 19 | +
|
| 20 | +Each rank ``(sp, tp)`` holds: |
| 21 | +
|
| 22 | +* ``a_local`` ``[M/SP, K/TP]`` – its row-shard × K-shard of **A** |
| 23 | +* ``b_local`` ``[K/TP, N ]`` – its K-shard of **B** (full N on every rank) |
| 24 | +
|
| 25 | +The kernel steps for every output tile ``[tile_m, tile_n]``: |
| 26 | +
|
| 27 | +1. Compute partial GEMM: ``a_local @ b_local`` → partial ``[M/SP, N]`` |
| 28 | +2. Write partial result to a symmetric-memory buffer (visible to TP peers). |
| 29 | +3. Intra-TP barrier (release + acquire) so every peer can read the partial. |
| 30 | +4. Sum all TP peers' partials in-kernel (fused all-reduce over the TP group). |
| 31 | +5. Intra-TP release barrier to allow cleanup / the next tile. |
| 32 | +""" |
| 33 | + |
| 34 | +from __future__ import annotations |
| 35 | + |
| 36 | +import functools |
| 37 | +import os |
| 38 | + |
| 39 | +import torch |
| 40 | +from torch._C._distributed_c10d import _SymmetricMemory |
| 41 | +import torch.distributed as dist |
| 42 | +import torch.distributed._symmetric_memory as symm_mem |
| 43 | +from torch.distributed.device_mesh import init_device_mesh |
| 44 | + |
| 45 | +import helion |
| 46 | +from helion._testing import DEVICE |
| 47 | +from helion._testing import run_example |
| 48 | +import helion.language as hl |
| 49 | +from helion.runtime.dist_utils import symm_mem_sync |
| 50 | + |
| 51 | + |
@helion.kernel(
    config=helion.Config(
        block_sizes=[64, 64, 32],
        num_warps=8,
        num_stages=3,
        indexing="block_ptr",
    ),
    static_shapes=True,
    # NOTE(review): presumably triggered by the host-side
    # get_remote_tensors() call below, which runs in the wrapper rather
    # than in the generated device code — confirm against Helion docs.
    ignore_warnings=[helion.exc.TensorOperationInWrapper],
)
def two_dim_parallel_matmul_kernel(
    a_local: torch.Tensor,  # [M/SP, K/TP] this rank's row × K shard of A
    b_local: torch.Tensor,  # [K/TP, N ] this rank's K shard of B (full N)
    symm_mem_buf: torch.Tensor,  # [M/SP, N ] symmetric-memory scratch
    signal_pad_ptrs: torch.Tensor,  # signal-pad pointer table for symm_mem_sync
    TP_RANK: hl.constexpr,
    TP_SIZE: hl.constexpr,
    GROUP_NAME: hl.ProcessGroupName,
) -> torch.Tensor:
    """
    Fused 2D-parallel (SP × TP) matmul kernel.

    Dimension 1 – Sequence Parallel (SP): different ranks own different M rows.
    No communication occurs across SP ranks; each computes its rows independently.

    Dimension 2 – Tensor Parallel (TP): ranks in the same SP group each own a
    K-shard of A and B. After the local partial GEMM, an in-kernel all-reduce
    over the TP group produces the correct output for this rank's M rows.

    Args:
        a_local: ``[M/SP, K/TP]`` shard of A owned by this rank.
        b_local: ``[K/TP, N]`` shard of B owned by this rank.
        symm_mem_buf: ``[M/SP, N]`` symmetric-memory scratch readable by TP peers.
        signal_pad_ptrs: Device table of signal-pad pointers for ``symm_mem_sync``.
        TP_RANK: This rank's index within the TP group (compile-time constant).
        TP_SIZE: Number of ranks in the TP group (compile-time constant).
        GROUP_NAME: Name of the TP process group used for peer-buffer lookup.

    Returns:
        ``[M/SP, N]`` output rows owned by this rank, fully reduced over TP.
    """
    M_local, K_local = a_local.size()
    N = b_local.size(1)
    out = torch.empty([M_local, N], dtype=a_local.dtype, device=a_local.device)

    # Symmetric-memory views of every TP peer's scratch buffer.
    remote_bufs = torch.ops.symm_mem.get_remote_tensors(symm_mem_buf, GROUP_NAME)

    for tile_m, tile_n in hl.tile([M_local, N]):
        # ── Dimension 2 (Tensor Parallel): partial GEMM over local K shard ──
        # Accumulate in fp32 regardless of input dtype.
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(K_local):
            acc = torch.addmm(acc, a_local[tile_m, tile_k], b_local[tile_k, tile_n])

        # Write partial result to this rank's symmetric-memory buffer.
        symm_mem_buf[tile_m, tile_n] = acc.to(a_local.dtype)

        # Barrier: release our write so peers can see it; acquire their writes.
        # NOTE(review): the two trailing booleans presumably select
        # release/acquire semantics — verify against symm_mem_sync's signature.
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, None, TP_RANK, TP_SIZE, True, True),
            output_like=None,
        )

        # ── Reduce: sum partials from all TP peers (fused intra-TP all-reduce) ──
        # Every rank sums all TP partials for its own rows (no extra broadcast).
        acc_full = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for peer_buf in remote_bufs:
            acc_full = acc_full + peer_buf[tile_m, tile_n].to(torch.float32)
        out[tile_m, tile_n] = acc_full.to(a_local.dtype)

        # Release barrier: signal that we are done reading the shared buffers.
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, None, TP_RANK, TP_SIZE, True, False),
            output_like=None,
        )

    return out
| 118 | + |
| 119 | + |
def helion_two_dim_parallel_matmul(
    a_local: torch.Tensor,
    b_local: torch.Tensor,
    tp_group: dist.ProcessGroup,
    symm_mem_buf: torch.Tensor,
) -> torch.Tensor:
    """
    Rendezvous on the TP group's symmetric-memory buffer, then launch the
    fused 2D-parallel matmul kernel.

    Args:
        a_local: Local A shard ``[M/SP, K/TP]``.
        b_local: Local B shard ``[K/TP, N ]``.
        tp_group: Intra-TP process group for this rank.
        symm_mem_buf: ``[M/SP, N]`` symmetric-memory scratch for this rank.
    """
    pg_name = tp_group.group_name  # type: ignore[missing-attribute]

    # The rendezvous handle carries the TP rank/world-size plus the device
    # pointer table for the signal pads used by the in-kernel barriers.
    handle = symm_mem.rendezvous(symm_mem_buf, pg_name)

    return two_dim_parallel_matmul_kernel(
        a_local,
        b_local,
        symm_mem_buf,
        handle.signal_pad_ptrs_dev,
        TP_RANK=handle.rank,
        TP_SIZE=handle.world_size,
        GROUP_NAME=pg_name,
    )
| 147 | + |
| 148 | + |
def reference_two_dim_parallel_matmul(
    a_local: torch.Tensor,
    b_local: torch.Tensor,
    tp_group: dist.ProcessGroup,
) -> torch.Tensor:
    """
    Reference: all-gather K shards within the TP group, then do a local matmul.

    Each peer's shard is gathered and re-assembled in TP-rank order, so the
    concatenated operands reproduce the full ``[M/SP, K] @ [K, N]`` product.
    """
    tp_size = dist.get_world_size(tp_group)

    def _gather_full(shard: torch.Tensor, cat_dim: int) -> torch.Tensor:
        # Collect this shard from every TP peer and stitch along cat_dim.
        parts = [torch.empty_like(shard) for _ in range(tp_size)]
        dist.all_gather(parts, shard, group=tp_group)
        return torch.cat(parts, dim=cat_dim)

    a_full = _gather_full(a_local, cat_dim=1)  # [M/SP, K]
    b_full = _gather_full(b_local, cat_dim=0)  # [K, N]

    # fp32 matmul to mirror the kernel's fp32 accumulation.
    return torch.mm(a_full.float(), b_full.float()).to(a_local.dtype)
| 170 | + |
| 171 | + |
def test(
    M: int,
    K: int,
    N: int,
    device: torch.device,
    dtype: torch.dtype,
    tp_group: dist.ProcessGroup,
    sp_rank: int,
    tp_size: int,
    sp_size: int,
) -> None:
    """Compare the fused 2D-parallel kernel against the all-gather reference."""
    tp_rank = dist.get_rank(tp_group)

    rows_per_sp = M // sp_size
    k_per_tp = K // tp_size

    # A: every (sp, tp) coordinate gets its own distinct random block,
    # so the seed mixes both rank coordinates.
    torch.manual_seed(42 + sp_rank * tp_size + tp_rank)
    a_local = torch.randn(rows_per_sp, k_per_tp, dtype=dtype, device=device)

    # B: seeded by TP rank only — ranks with the same TP coordinate in
    # different SP groups must hold identical K-shards of B (full N columns).
    torch.manual_seed(42 + tp_rank)
    b_local = torch.randn(k_per_tp, N, dtype=dtype, device=device)

    # Symmetric-memory scratch shared with TP peers for the in-kernel reduce.
    symm_mem_buf = symm_mem.empty(
        rows_per_sp, N, dtype=a_local.dtype, device=a_local.device
    )
    symm_mem.rendezvous(symm_mem_buf, tp_group.group_name)

    kernel_fn = functools.partial(
        helion_two_dim_parallel_matmul, symm_mem_buf=symm_mem_buf
    )

    run_example(
        lambda a, b: kernel_fn(a, b, tp_group),
        lambda a, b: reference_two_dim_parallel_matmul(a, b, tp_group),
        (a_local, b_local),
        rtol=2e-1,
        atol=2e-1,
        process_group_name=tp_group.group_name,
    )
| 216 | + |
| 217 | + |
def main() -> None:
    """
    Entry point: initialize NCCL and a 2-D (SP × TP) device mesh, enable
    symmetric memory for the TP groups, and run the 2D-parallel matmul test.
    """
    # Enlarge the signal pads so the per-tile in-kernel barriers have room.
    _SymmetricMemory.signal_pad_size = 1024 * 1024 * 16

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size >= 4 and world_size % 2 == 0, (
        f"Requires at least 4 GPUs arranged as SP×TP mesh (got {world_size})"
    )

    torch.manual_seed(42 + local_rank)
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")

    # SP is the outer mesh dim (row parallelism); TP is the inner dim
    # (K-reduction parallelism).
    tp_size = 2
    sp_size = world_size // tp_size
    mesh = init_device_mesh(
        "cuda",
        (sp_size, tp_size),
        mesh_dim_names=("sp", "tp"),
    )

    test(
        M=1024,
        K=256,
        N=512,
        device=device,
        dtype=torch.float32,
        tp_group=mesh.get_group("tp"),
        sp_rank=local_rank // tp_size,
        tp_size=tp_size,
        sp_size=sp_size,
    )

    dist.destroy_process_group()
| 261 | + |
| 262 | + |
if __name__ == "__main__":
    """
    Run with:
    python -m torch.distributed.run --standalone \\
        --nproc-per-node 4 \\
        --rdzv-backend c10d --rdzv-endpoint localhost:0 \\
        examples/distributed/two_dim_parallel_matmul.py
    """
    # Symmetric memory + NCCL require CUDA; bail out early otherwise.
    assert DEVICE.type == "cuda", "Requires CUDA device"
    main()
0 commit comments