Skip to content

Commit 161d0a0

Browse files
committed
use the user-provided process group for autotuning
stack-info: PR: #1823, branch: shunting314/stack/21
1 parent 0728d11 commit 161d0a0

18 files changed

+598
-84
lines changed
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
"""
2+
2D-Parallel Matrix Multiplication Example
3+
==========================================
4+
Demonstrates two independent dimensions of parallelism in a single fused kernel:
5+
6+
* **Sequence parallelism (SP)**: each rank owns a disjoint shard of input rows.
7+
SP ranks never communicate with each other – they produce fully independent
8+
output rows.
9+
10+
* **Tensor parallelism (TP)**: within each SP group, every rank holds a shard of
11+
the K (inner/reduction) dimension. Partial products must be summed via an
12+
intra-TP all-reduce implemented with symmetric-memory signal pads.
13+
14+
Rank layout (example: SP=2, TP=2, total 4 GPUs)::
15+
16+
TP=0 TP=1
17+
SP=0: rank 0 rank 1 <- compute rows 0 .. M/2 of output
18+
SP=1: rank 2 rank 3 <- compute rows M/2 .. M of output
19+
20+
Each rank ``(sp, tp)`` holds:
21+
22+
* ``a_local`` ``[M/SP, K/TP]`` – its row-shard × K-shard of **A**
23+
* ``b_local`` ``[K/TP, N ]`` – its K-shard of **B** (full N on every rank)
24+
25+
The kernel steps for every output tile ``[tile_m, tile_n]``:
26+
27+
1. Compute partial GEMM: ``a_local @ b_local`` → partial ``[M/SP, N]``
28+
2. Write partial result to a symmetric-memory buffer (visible to TP peers).
29+
3. Intra-TP barrier (release + acquire) so every peer can read the partial.
30+
4. Sum all TP peers' partials in-kernel (fused all-reduce over the TP group).
31+
5. Intra-TP release barrier to allow cleanup / the next tile.
32+
"""
33+
34+
from __future__ import annotations
35+
36+
import functools
37+
import os
38+
39+
import torch
40+
from torch._C._distributed_c10d import _SymmetricMemory
41+
import torch.distributed as dist
42+
import torch.distributed._symmetric_memory as symm_mem
43+
from torch.distributed.device_mesh import init_device_mesh
44+
45+
import helion
46+
from helion._testing import DEVICE
47+
from helion._testing import run_example
48+
import helion.language as hl
49+
from helion.runtime.dist_utils import symm_mem_sync
50+
51+
52+
@helion.kernel(
    config=helion.Config(
        block_sizes=[64, 64, 32],
        num_warps=8,
        num_stages=3,
        indexing="block_ptr",
    ),
    static_shapes=True,
    ignore_warnings=[helion.exc.TensorOperationInWrapper],
)
def two_dim_parallel_matmul_kernel(
    a_local: torch.Tensor,  # [M/SP, K/TP]
    b_local: torch.Tensor,  # [K/TP, N ]
    symm_mem_buf: torch.Tensor,  # [M/SP, N ] symmetric-memory scratch
    signal_pad_ptrs: torch.Tensor,  # device pointers to TP peers' signal pads
    TP_RANK: hl.constexpr,  # this rank's index within the TP group
    TP_SIZE: hl.constexpr,  # number of ranks in the TP group
    GROUP_NAME: hl.ProcessGroupName,  # TP group name, used to resolve peer buffers
) -> torch.Tensor:
    """
    Fused 2D-parallel (SP × TP) matmul kernel.

    Dimension 1 – Sequence Parallel (SP): different ranks own different M rows.
    No communication occurs across SP ranks; each computes its rows independently.

    Dimension 2 – Tensor Parallel (TP): ranks in the same SP group each own a
    K-shard of A and B. After the local partial GEMM, an in-kernel all-reduce
    over the TP group produces the correct output for this rank's M rows.

    Returns:
        The fully reduced ``[M/SP, N]`` output for this rank's row shard,
        in ``a_local``'s dtype.
    """
    M_local, K_local = a_local.size()
    N = b_local.size(1)
    out = torch.empty([M_local, N], dtype=a_local.dtype, device=a_local.device)

    # Symmetric-memory views of every TP peer's scratch buffer.
    remote_bufs = torch.ops.symm_mem.get_remote_tensors(symm_mem_buf, GROUP_NAME)

    for tile_m, tile_n in hl.tile([M_local, N]):
        # ── Dimension 2 (Tensor Parallel): partial GEMM over local K shard ──
        # Accumulate in fp32 regardless of input dtype for numerical stability.
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(K_local):
            acc = torch.addmm(acc, a_local[tile_m, tile_k], b_local[tile_k, tile_n])

        # Write partial result to this rank's symmetric-memory buffer.
        symm_mem_buf[tile_m, tile_n] = acc.to(a_local.dtype)

        # Barrier: release our write so peers can see it; acquire their writes.
        # NOTE(review): the two trailing booleans look like (release, acquire)
        # flags of symm_mem_sync — confirm against its definition.
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, None, TP_RANK, TP_SIZE, True, True),
            output_like=None,
        )

        # ── Reduce: sum partials from all TP peers (fused intra-TP all-reduce) ──
        acc_full = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for peer_buf in remote_bufs:
            acc_full = acc_full + peer_buf[tile_m, tile_n].to(torch.float32)
        out[tile_m, tile_n] = acc_full.to(a_local.dtype)

        # Release barrier: signal that we are done reading the shared buffers,
        # so peers may safely overwrite them for the next tile.
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, None, TP_RANK, TP_SIZE, True, False),
            output_like=None,
        )

    return out
120+
def helion_two_dim_parallel_matmul(
    a_local: torch.Tensor,
    b_local: torch.Tensor,
    tp_group: dist.ProcessGroup,
    symm_mem_buf: torch.Tensor,
) -> torch.Tensor:
    """
    Rendezvous on the TP group's symmetric memory, then launch the fused
    2D-parallel matmul kernel.

    Args:
        a_local: Local A shard ``[M/SP, K/TP]``.
        b_local: Local B shard ``[K/TP, N ]``.
        tp_group: Intra-TP process group for this rank.
        symm_mem_buf: Symmetric-memory scratch buffer ``[M/SP, N]``.
    """
    name = tp_group.group_name  # type: ignore[missing-attribute]

    # The rendezvous handle exposes this rank's position in the TP group and
    # the device-side signal-pad pointers the kernel barriers need.
    handle = symm_mem.rendezvous(symm_mem_buf, name)

    return two_dim_parallel_matmul_kernel(
        a_local,
        b_local,
        symm_mem_buf,
        handle.signal_pad_ptrs_dev,
        TP_RANK=handle.rank,
        TP_SIZE=handle.world_size,
        GROUP_NAME=name,
    )
149+
def reference_two_dim_parallel_matmul(
    a_local: torch.Tensor,
    b_local: torch.Tensor,
    tp_group: dist.ProcessGroup,
) -> torch.Tensor:
    """
    Reference: all-gather K shards within the TP group, then do a local matmul.
    """
    world = dist.get_world_size(tp_group)

    def _gather_cat(shard: torch.Tensor, dim: int) -> torch.Tensor:
        # Collect every TP peer's shard and stitch them back along `dim`.
        pieces = [torch.empty_like(shard) for _ in range(world)]
        dist.all_gather(pieces, shard, group=tp_group)
        return torch.cat(pieces, dim=dim)

    a_full = _gather_cat(a_local, dim=1)  # K shards side by side → [M/SP, K]
    b_full = _gather_cat(b_local, dim=0)  # K shards stacked → [K, N]

    # Compute in fp32 and cast back, matching the kernel's accumulation dtype.
    return torch.mm(a_full.float(), b_full.float()).to(a_local.dtype)
172+
def test(
    M: int,
    K: int,
    N: int,
    device: torch.device,
    dtype: torch.dtype,
    tp_group: dist.ProcessGroup,
    sp_rank: int,
    tp_size: int,
    sp_size: int,
) -> None:
    """Compare the fused 2D-parallel kernel against the all-gather reference."""
    tp_rank = dist.get_rank(tp_group)
    torch.manual_seed(42 + sp_rank * tp_size + tp_rank)

    rows_per_sp = M // sp_size
    k_per_tp = K // tp_size

    # A: every (sp, tp) rank owns a distinct [M/SP, K/TP] block, hence the
    # per-rank seed above.
    a_local = torch.randn(rows_per_sp, k_per_tp, dtype=dtype, device=device)

    # B: ranks in the same TP group hold different K-shards, but ranks with the
    # same TP rank across SP groups must hold the *same* K-shard — so reseed
    # with a TP-rank-only seed before sampling B.
    torch.manual_seed(42 + tp_rank)
    b_local = torch.randn(k_per_tp, N, dtype=dtype, device=device)

    symm_mem_buf = symm_mem.empty(
        rows_per_sp, N, dtype=a_local.dtype, device=a_local.device
    )
    symm_mem.rendezvous(symm_mem_buf, tp_group.group_name)

    run_example(
        lambda a, b: helion_two_dim_parallel_matmul(
            a, b, tp_group, symm_mem_buf=symm_mem_buf
        ),
        lambda a, b: reference_two_dim_parallel_matmul(a, b, tp_group),
        (a_local, b_local),
        rtol=2e-1,
        atol=2e-1,
        process_group_name=tp_group.group_name,
    )
218+
def main() -> None:
    """
    Initialize a 2-D device mesh, enable symmetric memory for the TP groups,
    and run the 2D-parallel matmul test.

    Raises:
        RuntimeError: if WORLD_SIZE cannot be arranged as an SP×TP mesh.
    """
    # Enlarge the signal pads so the per-tile barriers in the kernel have room.
    _SymmetricMemory.signal_pad_size = 1024 * 1024 * 16

    # NOTE(review): LOCAL_RANK is used as the global rank below (device index,
    # seeding, sp_rank math) — correct for single-node launches; confirm before
    # running multi-node.
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Validate with an explicit exception rather than `assert`, which is
    # silently stripped when Python runs with -O.
    if world_size < 4 or world_size % 2 != 0:
        raise RuntimeError(
            f"Requires at least 4 GPUs arranged as SP×TP mesh (got {world_size})"
        )

    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")

    # Build a 2-D device mesh: SP (outer dim, row parallelism) × TP (inner dim,
    # K-reduction parallelism).
    tp_size = 2
    sp_size = world_size // tp_size
    mesh = init_device_mesh(
        "cuda",
        (sp_size, tp_size),
        mesh_dim_names=("sp", "tp"),
    )
    tp_group = mesh.get_group("tp")
    # Row-major rank layout: consecutive ranks share an SP row.
    sp_rank = rank // tp_size

    test(
        M=1024,
        K=256,
        N=512,
        device=device,
        dtype=torch.float32,
        tp_group=tp_group,
        sp_rank=sp_rank,
        tp_size=tp_size,
        sp_size=sp_size,
    )

    dist.destroy_process_group()
263+
if __name__ == "__main__":
    """
    Run with:
    python -m torch.distributed.run --standalone \\
        --nproc-per-node 4 \\
        --rdzv-backend c10d --rdzv-endpoint localhost:0 \\
        examples/distributed/two_dim_parallel_matmul.py
    """
    # Guard: the symmetric-memory path in this example requires a CUDA device.
    assert DEVICE.type == "cuda", "Requires CUDA device"
    main()

helion/_compiler/compile_environment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
self.index_dtype: torch.dtype = (
125125
index_dtype or settings.index_dtype or torch.int32
126126
)
127+
self.process_group_name = None
127128
backend_factory: dict[str, type[Backend]] = {
128129
"triton": TritonBackend,
129130
"pallas": PallasBackend,

0 commit comments

Comments
 (0)