Skip to content

Commit 1d684af

Browse files
committed
scaled matmul reduce scatter
1 parent 39e0be8 commit 1d684af

File tree

2 files changed

+235
-0
lines changed

2 files changed

+235
-0
lines changed
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""
2+
FP8 MatMul + Reduce-Scatter Fusion Example
3+
==========================================
4+
This example extends the matmul_reduce_scatter example to use FP8 inputs.
5+
Each rank holds FP8 A and B shards; the kernel computes a local FP8 GEMM
6+
(accumulating in FP32 via ``hl.dot``), writes the float16 partial result to
7+
a symmetric-memory buffer, performs an intra-group barrier, and then
8+
reduce-scatters: each rank accumulates the rows it owns from all peers'
9+
buffers, producing a ``[M//WORLD_SIZE, N]`` float16 output.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import functools
15+
import os
16+
17+
import torch
18+
from torch._C._distributed_c10d import _SymmetricMemory
19+
import torch.distributed as dist
20+
import torch.distributed._symmetric_memory as symm_mem
21+
22+
import helion
23+
from helion._testing import DEVICE
24+
from helion._testing import run_example
25+
import helion.language as hl
26+
from helion.runtime.dist_utils import symm_mem_sync
27+
28+
29+
@helion.kernel(
    config=helion.Config(
        block_sizes=[64, 64, 32],  # tile sizes for M, N, K
        num_warps=8,
        num_stages=3,
    ),
    static_shapes=True,
    ignore_warnings=[helion.exc.TensorOperationInWrapper],
)
def fp8_matmul_reduce_scatter_kernel(
    a: torch.Tensor,  # [M, K] float8_e4m3fn
    b: torch.Tensor,  # [K, N] float8_e4m3fn
    symm_mem_buffer: torch.Tensor,  # [M, N] float16, symmetric memory
    signal_pad_ptrs: torch.Tensor,  # per-rank signal-pad pointers for symm_mem_sync
    RANK: hl.constexpr,
    WORLD_SIZE: hl.constexpr,
    GROUP_NAME: hl.ProcessGroupName,
) -> torch.Tensor:
    """
    Fused FP8 MatMul + Reduce-Scatter kernel.

    Computes ``(A @ B).to(float16)`` in a distributed reduce-scatter pattern:
    each rank emits only its ``M // WORLD_SIZE`` output rows.
    """
    M, K = a.size()
    K2, N = b.size()
    # Rows of the final output owned by this rank.
    # NOTE(review): assumes M is divisible by WORLD_SIZE — confirm callers enforce this.
    M_scatter = M // WORLD_SIZE  # type: ignore[unsupported-operation]

    output = torch.empty([M_scatter, N], dtype=torch.float16, device=a.device)

    # Views of every peer's copy of the symmetric buffer (including our own).
    buffer_tuple = torch.ops.symm_mem.get_remote_tensors(symm_mem_buffer, GROUP_NAME)

    # Half-open row range [scatter_start, scatter_end) this rank reduces into.
    scatter_start = RANK * M_scatter  # type: ignore[unsupported-operation]
    scatter_end = scatter_start + M_scatter  # type: ignore[unsupported-operation]

    for tile_m, tile_n in hl.tile([M, N]):
        # FP8 GEMM tile, accumulating in FP32 via hl.dot
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(K):
            acc = hl.dot(a[tile_m, tile_k], b[tile_k, tile_n], acc=acc)

        # Store float16 partial result to this rank's symmetric-memory buffer
        symm_mem_buffer[tile_m, tile_n] = acc.to(torch.float16)

        # Barrier: release our write, acquire peers' writes
        # (hands_release=True, hands_acquire=True per the two trailing flags)
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, None, RANK, WORLD_SIZE, True, True),
            output_like=None,
        )

        # Reduce-scatter: accumulate only the rows this rank owns
        if tile_m.begin >= scatter_start and tile_m.begin < scatter_end:  # type: ignore[unsupported-operation]
            acc_reduce = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for remote_buffer in buffer_tuple:
                acc_reduce = acc_reduce + remote_buffer[tile_m, tile_n].to(
                    torch.float32
                )
            # Shift global rows into this rank's local [0, M_scatter) window.
            output[tile_m.index - scatter_start, tile_n] = acc_reduce.to(torch.float16)  # type: ignore[unsupported-operation]

        # Final barrier (release only) — keeps peers from racing ahead and
        # overwriting the buffer while others may still be reading this tile.
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, None, RANK, WORLD_SIZE, True, False),
            output_like=None,
        )

    return output
97+
98+
99+
def helion_fp8_matmul_reduce_scatter(
100+
symm_mem_buffer: torch.Tensor,
101+
a: torch.Tensor,
102+
b: torch.Tensor,
103+
) -> torch.Tensor:
104+
"""
105+
Wrapper that rendezvouss on the pre-allocated symmetric buffer and
106+
invokes the FP8 reduce-scatter kernel.
107+
108+
Args:
109+
symm_mem_buffer: Pre-allocated symmetric-memory buffer ``[M, N]`` float16.
110+
a: Local FP8 A shard ``[M, K]`` (``torch.float8_e4m3fn``).
111+
b: Local FP8 B shard ``[K, N]`` (``torch.float8_e4m3fn``).
112+
"""
113+
group = dist.group.WORLD
114+
if group is None:
115+
raise RuntimeError("Distributed group is not initialized")
116+
117+
symm_mem_hdl = symm_mem.rendezvous(symm_mem_buffer, group.group_name)
118+
119+
return fp8_matmul_reduce_scatter_kernel(
120+
a,
121+
b,
122+
symm_mem_buffer,
123+
symm_mem_hdl.signal_pad_ptrs_dev,
124+
RANK=symm_mem_hdl.rank,
125+
WORLD_SIZE=symm_mem_hdl.world_size,
126+
GROUP_NAME=group.group_name,
127+
)
128+
129+
130+
def reference_fp8_matmul_reduce_scatter(
    a: torch.Tensor,
    b: torch.Tensor,
) -> torch.Tensor:
    """
    Reference implementation: dequantize to float32, matmul, then
    reduce-scatter the float16 product along the M dimension.

    Raises:
        RuntimeError: If the default distributed group is not initialized.
    """
    world_group = dist.group.WORLD
    if world_group is None:
        raise RuntimeError("Distributed group is not initialized")

    # Full local GEMM in fp32, cast to the kernel's fp16 output dtype.
    partial = torch.mm(a.to(torch.float32), b.to(torch.float32)).to(torch.float16)

    nranks = dist.get_world_size(world_group)
    rows_per_rank = partial.shape[0] // nranks
    result = torch.empty(
        rows_per_rank, partial.shape[1], dtype=partial.dtype, device=partial.device
    )
    dist.reduce_scatter_tensor(result, partial, group=world_group)
    return result
148+
149+
150+
def test(M: int, N: int, K: int, device: torch.device) -> None:
    """Compare the fused FP8 reduce-scatter kernel against the reference."""
    rank = dist.get_rank()

    # Per-rank seed: each rank holds a distinct A shard.
    torch.manual_seed(42 + rank)
    a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)

    # Shared seed: every rank sees the same B matrix.
    torch.manual_seed(42)
    b = torch.randn(K, N, device=device).to(torch.float8_e4m3fn)

    # Symmetric-memory staging buffer for the partial [M, N] products.
    symm_mem_buffer = symm_mem.empty(M, N, dtype=torch.float16, device=device)
    symm_mem.rendezvous(symm_mem_buffer, dist.group.WORLD.group_name)  # type: ignore[union-attr]

    # Loose tolerances: FP8 inputs lose precision before the GEMM.
    run_example(
        functools.partial(helion_fp8_matmul_reduce_scatter, symm_mem_buffer),
        reference_fp8_matmul_reduce_scatter,
        (a, b),
        rtol=2e-1,
        atol=2e-1,
    )
172+
173+
174+
def main() -> None:
    """
    Per-process entry point: bind this process to its local CUDA device,
    initialize the NCCL process group, run the correctness test, and tear
    the group down.

    Expects ``LOCAL_RANK`` in the environment (set by torch.distributed.run).
    """
    _SymmetricMemory.signal_pad_size = 1024 * 1024 * 16
    rank = int(os.environ["LOCAL_RANK"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")

    # Destroy the group even if the test raises; otherwise peer ranks can
    # hang forever inside collectives waiting for this rank.
    try:
        test(M=512, N=768, K=1024, device=device)
    finally:
        dist.destroy_process_group()
185+
186+
187+
if __name__ == "__main__":
    """
    Run with:
    python -m torch.distributed.run --standalone \\
    --nproc-per-node 4 \\
    --rdzv-backend c10d --rdzv-endpoint localhost:0 \\
    examples/distributed/fp8_matmul_reduce_scatter.py
    """
    # DEVICE comes from helion._testing; this NCCL example only runs on CUDA.
    assert DEVICE.type == "cuda", "Requires CUDA device"
    main()

test/test_distributed.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,45 @@ def do_test_matmul_reduce_scatter(self, kernel, ref_kernel):
426426

427427
torch.testing.assert_close(result, expected, rtol=1e-1, atol=1e-1)
428428

429+
    @skipIfRocm("Distributed example requires CUDA/NCCL")
    @skipIfXPU("Distributed operations require CCL, not yet fully integrated")
    @skip_if_lt_x_gpu(4)
    def test_fp8_matmul_reduce_scatter(self):
        # float8_e4m3fn matmul needs Hopper-class hardware (SM90+).
        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("FP8 requires CUDA compute capability >= 9.0")
        self._init_process()

        mod = import_path(EXAMPLES_DIR / "distributed" / "fp8_matmul_reduce_scatter.py")

        _SymmetricMemory.signal_pad_size = 1024 * 1024 * 16
        M, N, K = 512, 768, 1024

        # Per-rank seed so each rank holds a distinct A shard...
        torch.manual_seed(42 + self.rank)
        a_fp32 = torch.randn(M, K, device=self.device)
        a = a_fp32.to(torch.float8_e4m3fn)

        # ...shared seed so every rank sees the same B matrix.
        torch.manual_seed(42)
        b_fp32 = torch.randn(K, N, device=self.device)
        b = b_fp32.to(torch.float8_e4m3fn)

        symm_mem_buffer = symm_mem.empty(M, N, dtype=torch.float16, device=self.device)
        symm_mem_hdl = symm_mem.rendezvous(symm_mem_buffer, dist.group.WORLD.group_name)

        result = mod.fp8_matmul_reduce_scatter_kernel(
            a,
            b,
            symm_mem_buffer,
            symm_mem_hdl.signal_pad_ptrs_dev,
            RANK=symm_mem_hdl.rank,
            WORLD_SIZE=symm_mem_hdl.world_size,
            GROUP_NAME=dist.group.WORLD.group_name,
        )

        expected = mod.reference_fp8_matmul_reduce_scatter(a, b)

        # Loose tolerances: FP8 quantization dominates the error budget.
        torch.testing.assert_close(result, expected, rtol=2e-1, atol=2e-1)
        self._cleanup_process()
467+
429468
@skipIfRocm("Distributed example requires CUDA/NCCL")
430469
@skipIfXPU("Distributed operations require CCL, not yet fully integrated")
431470
@skip_if_lt_x_gpu(4)

0 commit comments

Comments
 (0)