-
Notifications
You must be signed in to change notification settings - Fork 137
Expand file tree
/
Copy pathepilogue_subtiling.py
More file actions
146 lines (114 loc) · 3.89 KB
/
epilogue_subtiling.py
File metadata and controls
146 lines (114 loc) · 3.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
Epilogue Subtiling Example
==========================
This example demonstrates matmul kernels with heavy epilogues that benefit from
epilogue subtiling on Blackwell (sm_100+). Epilogue subtiling splits the store
from ``[BLOCK_M, BLOCK_N]`` into ``SUBTILE_FACTOR x [BLOCK_M, BLOCK_N / SUBTILE_FACTOR]``,
reducing the accumulator shared-memory footprint by ``SUBTILE_FACTOR`` (halving it for a
factor of 2) and enabling an extra pipeline stage.
"""
# %%
# Imports
# -------
# %%
from __future__ import annotations
import torch
import helion
from helion._testing import DEVICE
from helion._testing import HALF_DTYPE
from helion._testing import run_example
import helion.language as hl
# %%
# Kernel 1 -- Matmul + Residual + Bias + GELU + Cast
# ---------------------------------------------------
# CUTLASS-style residual + bias + GELU forward epilogue: two extra operand
# reads (residual, bias) are upcast to fp32 and fused into the output tile.
# %%
@helion.kernel(static_shapes=True)
def matmul_bias_residual_gelu_cast(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor,
) -> torch.Tensor:
    """
    Compute ``gelu(1.25 * (x @ w) + 0.5 * residual + bias)`` and store it as float16.

    The matmul accumulates in an fp32 tile. The residual is explicitly upcast to
    fp32 before scaling, and the bias is added to the fp32 value, so the entire
    epilogue runs in fp32. Only the final store narrows to float16.

    Args:
        x: Left operand of shape ``[m, k]``.
        w: Right operand of shape ``[k, n]``.
        bias: Indexed as ``bias[tile_n]``; expected shape ``[n]`` (broadcast over rows).
        residual: Indexed as ``residual[tile_m, tile_n]``; expected shape ``[m, n]``.

    Returns:
        A new ``[m, n]`` float16 tensor on ``x.device``.
    """
    m, k = x.size()
    k2, n = w.size()
    # Validate the shared reduction dimension up front; otherwise ``w`` would be
    # tiled using ``x``'s K extent with no error.
    assert k == k2, f"size mismatch {k} != {k2}"
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], w[tile_k, tile_n])
        # Epilogue: scale, add scaled residual and bias, activate -- all in fp32.
        val = acc * 1.25
        val = val + residual[tile_m, tile_n].to(torch.float32) * 0.5
        val = val + bias[tile_n]
        val = torch.nn.functional.gelu(val)
        out[tile_m, tile_n] = val.to(torch.float16)
    return out
# %%
# Kernel 2 -- Matmul + Bias + GELU with Auxiliary Output
# ------------------------------------------------------
# cuBLASLt / CUTLASS-style GELU+AUX forward epilogue that writes both
# the pre-activation (aux) and post-GELU (out) tensors.
# %%
@helion.kernel(static_shapes=True)
def matmul_bias_gelu_aux(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute a biased, scaled matmul and return both its GELU and its pre-activation.

    ``pre = 1.25 * (x @ w) + bias`` is formed in fp32. Both stores narrow to
    float16. GELU is applied to the fp32 ``pre``, not to the float16-rounded
    auxiliary copy.

    Args:
        x: Left operand of shape ``[m, k]``.
        w: Right operand of shape ``[k, n]``.
        bias: Indexed as ``bias[tile_n]``; expected shape ``[n]`` (broadcast over rows).

    Returns:
        ``(out, aux)``: ``out = gelu(pre)`` and ``aux = pre``, both new
        ``[m, n]`` float16 tensors on ``x.device``.
    """
    m, k = x.size()
    k2, n = w.size()
    # Validate the shared reduction dimension up front; otherwise ``w`` would be
    # tiled using ``x``'s K extent with no error.
    assert k == k2, f"size mismatch {k} != {k2}"
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    aux = torch.empty([m, n], dtype=torch.float16, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], w[tile_k, tile_n])
        pre = acc * 1.25
        pre = pre + bias[tile_n]
        # Two stores per tile: the pre-activation (aux) and the activated output.
        aux[tile_m, tile_n] = pre.to(torch.float16)
        out[tile_m, tile_n] = torch.nn.functional.gelu(pre).to(torch.float16)
    return out, aux
# %%
# Verification
# ------------
# %%
def check(m: int, k: int, n: int) -> None:
    """
    Compare both epilogue kernels against eager PyTorch references.

    Random inputs are drawn once, in the order x, w, bias, residual, and shared by
    both comparisons. Each reference upcasts to fp32, applies the same epilogue
    arithmetic as its kernel, and narrows the result to float16.

    Args:
        m: Rows of ``x`` and of the outputs.
        k: Shared reduction dimension of ``x`` and ``w``.
        n: Columns of ``w`` and of the outputs.
    """

    def make(*shape: int) -> torch.Tensor:
        # Every input uses the same device/dtype; only the shape varies.
        return torch.randn(list(shape), device=DEVICE, dtype=HALF_DTYPE)

    # Draw order matters: it fixes which random values each tensor receives.
    x = make(m, k)
    w = make(k, n)
    bias = make(n)
    residual = make(m, n)

    def baseline_residual_gelu(
        x: torch.Tensor,
        w: torch.Tensor,
        bias: torch.Tensor,
        residual: torch.Tensor,
    ) -> torch.Tensor:
        scaled = torch.matmul(x.float(), w.float()) * 1.25
        # Summation order matches the kernel: scaled + residual term, then + bias.
        summed = scaled + residual.float() * 0.5 + bias.float()
        return torch.nn.functional.gelu(summed).to(torch.float16)

    def baseline_gelu_aux(
        x: torch.Tensor,
        w: torch.Tensor,
        bias: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        pre = torch.matmul(x.float(), w.float()) * 1.25 + bias.float()
        activated = torch.nn.functional.gelu(pre)
        return activated.to(torch.float16), pre.to(torch.float16)

    run_example(
        matmul_bias_residual_gelu_cast,
        baseline_residual_gelu,
        (x, w, bias, residual),
    )
    run_example(
        matmul_bias_gelu_aux,
        baseline_gelu_aux,  # pyrefly: ignore[bad-argument-type]
        (x, w, bias),
    )
# %%
# Main
# ----
# %%
def main() -> None:
    """Run the verification suite on a square 8192 x 8192 x 8192 problem."""
    size = 8192
    check(m=size, k=size, n=size)


if __name__ == "__main__":
    main()