Skip to content

Commit bafff2a

Browse files
committed
Add epilogue subtiling example and tuple output support in run_example
1 parent 8823a17 commit bafff2a

File tree

3 files changed

+164
-16
lines changed

3 files changed

+164
-16
lines changed

examples/epilogue_subtiling.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""
2+
Epilogue Subtiling Example
3+
==========================
4+
5+
This example demonstrates matmul kernels with heavy epilogues that benefit from
6+
epilogue subtiling on Blackwell (sm_100+). Epilogue subtiling splits the store
7+
from ``[BLOCK_M, BLOCK_N]`` into ``SUBTILE_FACTOR x [BLOCK_M, BLOCK_N / SUBTILE_FACTOR]``,
8+
halving the accumulator shared-memory footprint and enabling an extra pipeline stage.
9+
"""
10+
11+
# %%
12+
# Imports
13+
# -------
14+
15+
# %%
16+
from __future__ import annotations
17+
18+
import torch
19+
20+
import helion
21+
from helion._testing import DEVICE
22+
from helion._testing import HALF_DTYPE
23+
from helion._testing import run_example
24+
import helion.language as hl
25+
26+
# %%
27+
# Kernel 1 -- Matmul + Residual + Bias + GELU + Cast
28+
# ---------------------------------------------------
29+
# CUTLASS-style residual + bias + GELU forward epilogue with two
30+
# fp32 reads (residual, bias) fused into the output tile.
31+
32+
33+
# %%
34+
@helion.kernel(static_shapes=True)
def matmul_bias_residual_gelu_cast(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor,
) -> torch.Tensor:
    """Matmul with a fused residual + bias + GELU epilogue, stored as fp16.

    Computes ``gelu((x @ w) * 1.25 + residual * 0.5 + bias)`` tile by tile,
    accumulating the matmul in float32 and casting the final tile to float16.
    The heavy multi-read epilogue is what makes this kernel benefit from
    epilogue subtiling (see module docstring).

    Args:
        x: Left matmul operand, shape ``[m, k]``.
        w: Right matmul operand, shape ``[k, n]``.
        bias: Per-column bias, shape ``[n]`` (broadcast across rows).
        residual: Residual input, shape ``[m, n]``, scaled by 0.5 in the epilogue.

    Returns:
        Float16 tensor of shape ``[m, n]``.
    """
    m, k = x.size()
    _, n = w.size()
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)

    # One iteration per [tile_m, tile_n] output tile.
    for tile_m, tile_n in hl.tile([m, n]):
        # fp32 accumulator for numerical stability of the K reduction.
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], w[tile_k, tile_n])

        # Epilogue: alpha-style scale, scaled residual add, bias add, GELU.
        val = acc * 1.25
        val = val + residual[tile_m, tile_n].to(torch.float32) * 0.5
        val = val + bias[tile_n]
        val = torch.nn.functional.gelu(val)
        # Single fp16 store per tile — the store helion may subtile.
        out[tile_m, tile_n] = val.to(torch.float16)

    return out
57+
58+
59+
# %%
60+
# Kernel 2 -- Matmul + Bias + GELU with Auxiliary Output
61+
# ------------------------------------------------------
62+
# cuBLASLt / CUTLASS-style GELU+AUX forward epilogue that writes both
63+
# the pre-activation (aux) and post-GELU (out) tensors.
64+
65+
66+
# %%
67+
@helion.kernel(static_shapes=True)
def matmul_bias_gelu_aux(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Matmul + bias + GELU that also emits the pre-activation tensor.

    Mirrors the cuBLASLt/CUTLASS GELU+AUX forward epilogue: both the
    pre-activation (``aux``) and post-GELU (``out``) values are written,
    each cast to float16.

    Args:
        x: Left matmul operand, shape ``[m, k]``.
        w: Right matmul operand, shape ``[k, n]``.
        bias: Per-column bias, shape ``[n]`` (broadcast across rows).

    Returns:
        Tuple ``(out, aux)`` of float16 ``[m, n]`` tensors, where
        ``aux = (x @ w) * 1.25 + bias`` and ``out = gelu(aux)``
        (GELU applied to the fp32 pre-activation before the cast).
    """
    m, k = x.size()
    _, n = w.size()
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    aux = torch.empty([m, n], dtype=torch.float16, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        # fp32 accumulator for the K reduction.
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], w[tile_k, tile_n])

        # Pre-activation: alpha-style scale plus broadcast bias.
        pre = acc * 1.25
        pre = pre + bias[tile_n]
        # Two stores per tile: pre-activation (aux) and activated output.
        aux[tile_m, tile_n] = pre.to(torch.float16)
        out[tile_m, tile_n] = torch.nn.functional.gelu(pre).to(torch.float16)

    return out, aux
89+
90+
91+
# %%
92+
# Verification
93+
# ------------
94+
95+
96+
# %%
97+
def check(m: int, k: int, n: int) -> None:
    """Verify both kernels against eager fp32 references via ``run_example``.

    Args:
        m: Rows of the left operand / output.
        k: Shared (reduction) dimension.
        n: Columns of the right operand / output.
    """

    def _rand(*shape: int) -> torch.Tensor:
        return torch.randn(list(shape), device=DEVICE, dtype=HALF_DTYPE)

    # NOTE: allocation order fixes the RNG sequence — keep it stable.
    lhs = _rand(m, k)
    rhs = _rand(k, n)
    col_bias = _rand(n)
    res = _rand(m, n)

    def baseline_residual_gelu(
        a: torch.Tensor,
        b: torch.Tensor,
        c: torch.Tensor,
        r: torch.Tensor,
    ) -> torch.Tensor:
        # Eager fp32 reference for the residual+bias+GELU epilogue.
        prod = a.float() @ b.float()
        fused = prod * 1.25 + r.float() * 0.5 + c.float()
        return torch.nn.functional.gelu(fused).half()

    run_example(
        matmul_bias_residual_gelu_cast,
        baseline_residual_gelu,
        (lhs, rhs, col_bias, res),
    )

    def baseline_gelu_aux(
        a: torch.Tensor,
        b: torch.Tensor,
        c: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Eager fp32 reference returning (post-GELU, pre-activation).
        prod = a.float() @ b.float()
        pre = prod * 1.25 + c.float()
        return torch.nn.functional.gelu(pre).half(), pre.half()

    run_example(
        matmul_bias_gelu_aux,
        baseline_gelu_aux,  # pyrefly: ignore[bad-argument-type]
        (lhs, rhs, col_bias),
    )
133+
134+
135+
# %%
136+
# Main
137+
# ----
138+
139+
140+
# %%
141+
def main() -> None:
    """Run the correctness check on a large square (8192^3) problem."""
    check(8192, 8192, 8192)
143+
144+
145+
# Script entry point: run the check only when executed directly,
# not when imported (e.g. by docs tooling).
if __name__ == "__main__":
    main()

helion/_compiler/device_function.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,6 @@ def allocate_store_index(self) -> int:
325325
self.device_memory_op_index += 1
326326
return idx
327327

328-
def allocate_store_index(self) -> int:
329-
"""Bump store counters and return the indexing strategy slot."""
330-
self.device_store_index += 1
331-
idx = self.device_memory_op_index
332-
self.device_memory_op_index += 1
333-
return idx
334-
335328
def get_indexing_strategy(self, index: int) -> IndexingStrategy:
336329
from .indexing_strategy import IndexingStrategy
337330
from .indexing_strategy import PointerIndexingStrategy

helion/_testing.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,14 @@ def code_and_output(
937937
return code, result
938938

939939

940+
def _as_tensors(result: object) -> list[torch.Tensor]:
941+
"""Normalize a single tensor or tuple of tensors to a flat list."""
942+
if isinstance(result, tuple):
943+
return [t.clone() for t in result]
944+
assert isinstance(result, torch.Tensor)
945+
return [result.clone()]
946+
947+
940948
def run_example(
941949
kernel_fn: Callable[..., torch.Tensor] | Kernel | dict[str, Kernel],
942950
baseline_fn: Callable[..., torch.Tensor] | dict[str, Callable[..., torch.Tensor]],
@@ -975,20 +983,21 @@ def run_example(
975983

976984
# Check correctness against first baseline
977985
first_baseline_name, first_baseline_func = next(iter(baselines.items()))
978-
expected = first_baseline_func(*args).clone()
986+
expected = _as_tensors(first_baseline_func(*args))
979987

980988
for name, func in {**kernels, **baselines}.items():
981989
if name != first_baseline_name:
982990
print(f"Testing {name} correctness...", file=sys.stderr)
983-
# Clone args to avoid buffer donation issues (e.g., Pallas/TPU)
984991
cloned_args = _clone_args(args)
985-
result = func(*cloned_args).clone()
986-
torch.testing.assert_close(
987-
result.to(torch.float32),
988-
expected.to(torch.float32),
989-
rtol=rtol,
990-
atol=atol,
991-
)
992+
result = _as_tensors(func(*cloned_args))
993+
assert len(result) == len(expected)
994+
for r, e in zip(result, expected, strict=True):
995+
torch.testing.assert_close(
996+
r.to(torch.float32),
997+
e.to(torch.float32),
998+
rtol=rtol,
999+
atol=atol,
1000+
)
9921001

9931002
# Test backward pass
9941003
if bwd:

0 commit comments

Comments
 (0)