Skip to content

Commit 546ff56

Browse files
committed
scaled matmul reduce scatter
stack-info: PR: #1842, branch: shunting314/stack/24
1 parent 9b46217 commit 546ff56

File tree

2 files changed

+302
-0
lines changed

2 files changed

+302
-0
lines changed
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
"""
2+
FP8 MatMul + Reduce-Scatter Fusion Example
3+
==========================================
4+
This example extends the matmul_reduce_scatter example to use FP8 inputs.
5+
Each rank holds FP8 A and B shards; the kernel computes a local FP8 GEMM
6+
(accumulating in FP32 via ``hl.dot``), applies per-row/per-column scales,
7+
writes the bfloat16 partial result to a symmetric-memory buffer, performs an
8+
intra-group barrier, and then reduce-scatters: each rank accumulates the rows
9+
it owns from all peers' buffers, producing a ``[M//WORLD_SIZE, N]`` bfloat16
10+
output.
11+
"""
12+
13+
from __future__ import annotations
14+
15+
import functools
16+
import os
17+
18+
import torch
19+
from torch._C._distributed_c10d import _SymmetricMemory
20+
import torch.distributed as dist
21+
import torch.distributed._symmetric_memory as symm_mem
22+
23+
import helion
24+
from helion._testing import assert_close_with_mismatch_tolerance
25+
from helion._testing import DEVICE
26+
from helion._testing import run_example
27+
import helion.language as hl
28+
from helion.runtime.dist_utils import symm_mem_sync
29+
30+
# Shared accuracy thresholds used both by the autotuner's baseline check and
# by the example's final comparison against the references. FP8 GEMM is
# low-precision, hence the loose atol/rtol plus a small mismatch allowance.
tolerance = {
    "atol": 5e-1,
    "rtol": 5e-1,
    "max_mismatch_pct": 1e-3,
}
35+
36+
37+
@helion.kernel(
    config=helion.Config(
        block_sizes=[64, 64, 32],  # tile sizes along M, N, K
        num_warps=8,
        num_stages=3,
    ),
    static_shapes=True,
    # The wrapper-level symmetric-memory ops below are intentional, not a bug.
    ignore_warnings=[helion.exc.TensorOperationInWrapper],
    # Autotuning candidates are accepted within the loose FP8 tolerance.
    autotune_baseline_accuracy_check_fn=functools.partial(
        assert_close_with_mismatch_tolerance,
        **tolerance,
    ),
)
def fp8_matmul_reduce_scatter_kernel(
    a: torch.Tensor,  # [M, K] float8_e4m3fn
    b: torch.Tensor,  # [K, N] float8_e4m3fn
    scale_a: torch.Tensor,  # [M, 1] float32, per-row scale
    scale_b: torch.Tensor,  # [1, N] float32, per-column scale
    symm_mem_buffer: torch.Tensor,  # [M, N] bfloat16, symmetric memory
    signal_pad_ptrs: torch.Tensor,  # device pointers to all ranks' signal pads
    RANK: hl.constexpr,
    WORLD_SIZE: hl.constexpr,
    GROUP_NAME: hl.ProcessGroupName,
) -> torch.Tensor:
    """
    Fused FP8 MatMul + Reduce-Scatter kernel.

    Computes ``(scale_a * scale_b * (A @ B)).to(bfloat16)`` in a distributed
    reduce-scatter pattern: each rank emits only its ``M // WORLD_SIZE`` output rows.

    Args:
        a: Local FP8 A shard ``[M, K]``.
        b: Local FP8 B shard ``[K, N]``.
        scale_a: Per-row dequantization scale ``[M, 1]``.
        scale_b: Per-column dequantization scale ``[1, N]``.
        symm_mem_buffer: Rendezvoused symmetric-memory buffer ``[M, N]`` used
            to exchange bfloat16 partial results between ranks.
        signal_pad_ptrs: Signal-pad pointer table for ``symm_mem_sync`` barriers.
        RANK: This rank's index within the group (compile-time constant).
        WORLD_SIZE: Number of ranks in the group (compile-time constant).
        GROUP_NAME: Process-group name used to resolve peers' buffers.

    Returns:
        ``[M // WORLD_SIZE, N]`` bfloat16 tensor holding the rows this rank owns,
        summed over all ranks' partial results.

    NOTE(review): ``K2`` is unpacked but never checked against ``K``; presumably
    shape agreement is the caller's responsibility — confirm.
    """
    M, K = a.size()
    K2, N = b.size()
    M_scatter = M // WORLD_SIZE  # rows owned by each rank  # type: ignore[unsupported-operation]

    output = torch.empty([M_scatter, N], dtype=torch.bfloat16, device=a.device)

    # One view per peer rank into the shared [M, N] symmetric buffer.
    buffer_tuple = torch.ops.symm_mem.get_remote_tensors(symm_mem_buffer, GROUP_NAME)

    # Half-open row range [scatter_start, scatter_end) that this rank reduces.
    scatter_start = RANK * M_scatter  # type: ignore[unsupported-operation]
    scatter_end = scatter_start + M_scatter  # type: ignore[unsupported-operation]

    for tile_m, tile_n in hl.tile([M, N]):
        # FP8 GEMM tile, accumulating in FP32
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(K):
            acc = hl.dot(a[tile_m, tile_k], b[tile_k, tile_n], acc=acc)

        # Apply per-row and per-column scales
        acc = acc * scale_a[tile_m, :].to(torch.float32) * scale_b[:, tile_n].to(
            torch.float32
        )

        # Store bfloat16 partial result to this rank's symmetric-memory buffer
        symm_mem_buffer[tile_m, tile_n] = acc.to(torch.bfloat16)

        # Barrier: release our write, acquire peers' writes
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, None, RANK, WORLD_SIZE, True, True),
            output_like=None,
        )

        # Reduce-scatter: accumulate only the rows this rank owns
        if tile_m.begin >= scatter_start and tile_m.begin < scatter_end:  # type: ignore[unsupported-operation]
            acc_reduce = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for remote_buffer in buffer_tuple:
                acc_reduce = acc_reduce + remote_buffer[tile_m, tile_n].to(
                    torch.float32
                )
            # Shift global row indices into this rank's local output coordinates.
            output[tile_m.index - scatter_start, tile_n] = acc_reduce.to(torch.bfloat16)  # type: ignore[unsupported-operation]

        # Final barrier (release only)
        hl.triton_kernel(
            symm_mem_sync,
            args=(signal_pad_ptrs, None, RANK, WORLD_SIZE, True, False),
            output_like=None,
        )

    return output
116+
117+
118+
def helion_fp8_matmul_reduce_scatter(
    symm_mem_buffer: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
) -> torch.Tensor:
    """
    Rendezvous on the pre-allocated symmetric buffer and invoke the FP8
    reduce-scatter kernel.

    Args:
        symm_mem_buffer: Pre-allocated symmetric-memory buffer ``[M, N]`` bfloat16.
        a: Local FP8 A shard ``[M, K]`` (``torch.float8_e4m3fn``).
        b: Local FP8 B shard ``[K, N]`` (``torch.float8_e4m3fn``).
        scale_a: Per-row scale ``[M, 1]`` float32.
        scale_b: Per-column scale ``[1, N]`` float32.
    """
    world = dist.group.WORLD
    if world is None:
        raise RuntimeError("Distributed group is not initialized")

    # Rendezvous yields the handle carrying rank/world-size and signal pads.
    handle = symm_mem.rendezvous(symm_mem_buffer, world.group_name)

    return fp8_matmul_reduce_scatter_kernel(
        a,
        b,
        scale_a,
        scale_b,
        symm_mem_buffer,
        handle.signal_pad_ptrs_dev,
        RANK=handle.rank,
        WORLD_SIZE=handle.world_size,
        GROUP_NAME=world.group_name,
    )
153+
154+
155+
def reference_fp8_matmul_reduce_scatter(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
) -> torch.Tensor:
    """
    Reference: FP8 scaled matmul (cuBLAS) followed by an NCCL reduce-scatter
    along the M dimension.
    """
    world = dist.group.WORLD
    if world is None:
        raise RuntimeError("Distributed group is not initialized")

    # Full local partial result in bfloat16, then scatter rows across ranks.
    partial = torch._scaled_mm(
        a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16
    )

    rows_per_rank = partial.shape[0] // dist.get_world_size(world)
    # new_empty inherits partial's dtype and device.
    out = partial.new_empty((rows_per_rank, partial.shape[1]))
    dist.reduce_scatter_tensor(out, partial, group=world)
    return out
175+
176+
177+
def reference_fused_scaled_matmul_reduce_scatter(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
) -> torch.Tensor:
    """
    Reference using PyTorch's built-in
    ``_fused_scaled_matmul_reduce_scatter`` kernel.
    """
    world = dist.group.WORLD
    if world is None:
        raise RuntimeError("Distributed group is not initialized")

    # output_shape is the pre-scatter [M, N] shape; scatter happens along dim 0.
    m = a.shape[0]
    n = b.shape[1]
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        a,
        b,
        scale_a,
        scale_b,
        reduce_op="sum",
        orig_scatter_dim=0,
        scatter_dim_after_maybe_reshape=0,
        group_name=world.group_name,
        output_shape=[m, n],
        out_dtype=torch.bfloat16,
    )
204+
205+
206+
def test(M: int, N: int, K: int, device: torch.device) -> None:
    """Test the FP8 reduce-scatter kernel against the references."""
    rank = dist.get_rank()

    # A differs per rank (rank-dependent seed).
    torch.manual_seed(23 + rank)
    a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)

    # B is identical on every rank (shared seed). The transpose/contiguous
    # round-trip leaves a [K, N] tensor with column-major strides — presumably
    # required by the FP8 GEMM's operand-layout constraints; confirm.
    torch.manual_seed(23)
    b = torch.randn(K, N, device=device).to(torch.float8_e4m3fn).t().contiguous().t()

    scale_a = torch.rand(M, 1, device=device)
    scale_b = torch.rand(1, N, device=device)

    # Allocate and rendezvous the symmetric exchange buffer once, up front.
    symm_mem_buffer = symm_mem.empty(M, N, dtype=torch.bfloat16, device=device)
    symm_mem.rendezvous(symm_mem_buffer, dist.group.WORLD.group_name)  # type: ignore[union-attr]

    baselines = {
        "nccl+cublas": reference_fp8_matmul_reduce_scatter,
        "fused_baseline": reference_fused_scaled_matmul_reduce_scatter,
    }
    run_example(
        functools.partial(helion_fp8_matmul_reduce_scatter, symm_mem_buffer),
        baselines,
        (a, b, scale_a, scale_b),
        **tolerance,
    )
230+
231+
232+
def main() -> None:
    """Per-process entry point: set up NCCL, run the example, tear down."""
    # Enlarge the symmetric-memory signal pads before any allocation.
    _SymmetricMemory.signal_pad_size = 1024 * 1024 * 16

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.manual_seed(42 + local_rank)
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")

    test(M=512, N=768, K=1024, device=device)

    dist.destroy_process_group()
243+
244+
245+
if __name__ == "__main__":
    # The string below is a no-op expression kept as in-place usage notes.
    """
    Run with:
    python -m torch.distributed.run --standalone \\
        --nproc-per-node 4 \\
        --rdzv-backend c10d --rdzv-endpoint localhost:0 \\
        examples/distributed/fp8_matmul_reduce_scatter.py
    """
    # NCCL + symmetric memory in this example require a CUDA device.
    assert DEVICE.type == "cuda", "Requires CUDA device"
    main()

test/test_distributed.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,54 @@ def do_test_matmul_reduce_scatter(self, kernel, ref_kernel):
426426

427427
torch.testing.assert_close(result, expected, rtol=1e-1, atol=1e-1)
428428

429+
    @skipIfRocm("Distributed example requires CUDA/NCCL")
    @skipIfXPU("Distributed operations require CCL, not yet fully integrated")
    @skip_if_lt_x_gpu(4)
    def test_fp8_matmul_reduce_scatter(self):
        # Verifies the FP8 matmul + reduce-scatter example kernel against the
        # nccl+cublas reference from the example module itself.
        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("FP8 requires CUDA compute capability >= 9.0")
        self._init_process()

        mod = import_path(EXAMPLES_DIR / "distributed" / "fp8_matmul_reduce_scatter.py")

        # Enlarge signal pads before allocating symmetric memory (matches the
        # example's main()).
        _SymmetricMemory.signal_pad_size = 1024 * 1024 * 16
        M, N, K = 512, 768, 1024

        # A differs per rank; B is identical everywhere (shared seed).
        torch.manual_seed(42 + self.rank)
        a = torch.randn(M, K, device=self.device).to(torch.float8_e4m3fn)

        torch.manual_seed(42)
        # Transpose round-trip gives B column-major strides — presumably the
        # layout the FP8 GEMM expects; mirrors the example's input prep.
        b = (
            torch.randn(K, N, device=self.device)
            .to(torch.float8_e4m3fn)
            .t()
            .contiguous()
            .t()
        )

        scale_a = torch.rand(M, 1, device=self.device)
        scale_b = torch.rand(1, N, device=self.device)

        symm_mem_buffer = symm_mem.empty(M, N, dtype=torch.bfloat16, device=self.device)
        symm_mem_hdl = symm_mem.rendezvous(symm_mem_buffer, dist.group.WORLD.group_name)

        # Call the kernel directly (bypassing the example's wrapper) with the
        # rendezvoused handle's rank/world-size and signal pads.
        result = mod.fp8_matmul_reduce_scatter_kernel(
            a,
            b,
            scale_a,
            scale_b,
            symm_mem_buffer,
            symm_mem_hdl.signal_pad_ptrs_dev,
            RANK=symm_mem_hdl.rank,
            WORLD_SIZE=symm_mem_hdl.world_size,
            GROUP_NAME=dist.group.WORLD.group_name,
        )

        expected = mod.reference_fp8_matmul_reduce_scatter(a, b, scale_a, scale_b)

        # Loose tolerances: FP8 accumulation differs between the Helion kernel
        # and the cuBLAS reference path.
        torch.testing.assert_close(result, expected, rtol=8e-1, atol=8e-1)
        self._cleanup_process()
476+
429477
@skipIfRocm("Distributed example requires CUDA/NCCL")
430478
@skipIfXPU("Distributed operations require CCL, not yet fully integrated")
431479
@skip_if_lt_x_gpu(4)

0 commit comments

Comments
 (0)