"""
Epilogue Subtiling Example
==========================

This example demonstrates matmul kernels with heavy epilogues that benefit from
epilogue subtiling on Blackwell (sm_100+). Epilogue subtiling splits the store of
a ``[BLOCK_M, BLOCK_N]`` output tile into ``SUBTILE_FACTOR`` stores of
``[BLOCK_M, BLOCK_N / SUBTILE_FACTOR]``, shrinking the accumulator's shared-memory
footprint by that factor (halving it for ``SUBTILE_FACTOR = 2``) and freeing enough
shared memory for an extra pipeline stage.
"""

# %%
# Imports
# -------

# %%
from __future__ import annotations

import torch

import helion
from helion._testing import DEVICE
from helion._testing import HALF_DTYPE
from helion._testing import run_example
import helion.language as hl

# %%
# Kernel 1 -- Matmul + Residual + Bias + GELU + Cast
# --------------------------------------------------
# A CUTLASS-style residual + bias + GELU forward epilogue: two extra
# tensor reads (residual, bias) are upcast to fp32 and fused into the
# output tile before the fp16 cast.


# %%
@helion.kernel(static_shapes=True)
def matmul_bias_residual_gelu_cast(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor,
) -> torch.Tensor:
    m, k = x.size()
    _, n = w.size()
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        # Accumulate the matmul in fp32 over the K dimension.
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], w[tile_k, tile_n])

        # Heavy epilogue: scale, residual add, bias add, GELU, downcast.
        val = acc * 1.25
        val = val + residual[tile_m, tile_n].to(torch.float32) * 0.5
        val = val + bias[tile_n]
        val = torch.nn.functional.gelu(val)
        out[tile_m, tile_n] = val.to(torch.float16)

    return out


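# %%
# Helion autotunes each kernel on first call. To pin a configuration instead,
# pass ``config=helion.Config(...)`` to ``helion.kernel``. A minimal sketch --
# the ``epilogue_subtiling`` knob name and the values below are assumptions
# for illustration, not a tuned configuration::
#
#     @helion.kernel(
#         static_shapes=True,
#         config=helion.Config(block_sizes=[128, 128, 64], epilogue_subtiling=True),
#     )
#     def pinned_matmul(...): ...  # hypothetical variant of the kernel above

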
# %%
# Kernel 2 -- Matmul + Bias + GELU with Auxiliary Output
# ------------------------------------------------------
# cuBLASLt / CUTLASS-style GELU+AUX forward epilogue that writes both
# the pre-activation (aux) and post-GELU (out) tensors.


# %%
@helion.kernel(static_shapes=True)
def matmul_bias_gelu_aux(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    m, k = x.size()
    _, n = w.size()
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    aux = torch.empty([m, n], dtype=torch.float16, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], w[tile_k, tile_n])

        # Epilogue: scale and bias, then store both the pre-activation
        # (aux) and the post-GELU output (out).
        pre = acc * 1.25
        pre = pre + bias[tile_n]
        aux[tile_m, tile_n] = pre.to(torch.float16)
        out[tile_m, tile_n] = torch.nn.functional.gelu(pre).to(torch.float16)

    return out, aux


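# %%
# Saving the pre-activation mirrors cuBLASLt's GELU+AUX epilogues: a backward
# pass can recompute the GELU derivative from ``aux`` without redoing the
# matmul, using d/dx GELU(x) = Phi(x) + x * phi(x) for the exact (erf-based)
# GELU. A sketch of that consumer -- illustrative only, not part of this
# example::
#
#     INV_SQRT_2 = 0.7071067811865476    # 1 / sqrt(2)
#     INV_SQRT_2PI = 0.3989422804014327  # 1 / sqrt(2 * pi)
#
#     def gelu_bwd(grad_out: torch.Tensor, aux: torch.Tensor) -> torch.Tensor:
#         pre = aux.float()
#         phi = INV_SQRT_2PI * torch.exp(-0.5 * pre * pre)  # normal pdf
#         Phi = 0.5 * (1.0 + torch.erf(pre * INV_SQRT_2))   # normal cdf
#         return grad_out.float() * (Phi + pre * phi)

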
# %%
# Verification
# ------------


# %%
def check(m: int, k: int, n: int) -> None:
    x = torch.randn([m, k], device=DEVICE, dtype=HALF_DTYPE)
    w = torch.randn([k, n], device=DEVICE, dtype=HALF_DTYPE)
    bias = torch.randn([n], device=DEVICE, dtype=HALF_DTYPE)
    residual = torch.randn([m, n], device=DEVICE, dtype=HALF_DTYPE)

    # References compute in fp32, matching the kernels' accumulation dtype.
    def baseline_residual_gelu(
        x: torch.Tensor,
        w: torch.Tensor,
        bias: torch.Tensor,
        residual: torch.Tensor,
    ) -> torch.Tensor:
        acc = x.float() @ w.float()
        val = acc * 1.25 + residual.float() * 0.5 + bias.float()
        return torch.nn.functional.gelu(val).half()

    run_example(
        matmul_bias_residual_gelu_cast,
        baseline_residual_gelu,
        (x, w, bias, residual),
    )

    def baseline_gelu_aux(
        x: torch.Tensor,
        w: torch.Tensor,
        bias: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        acc = x.float() @ w.float()
        pre = acc * 1.25 + bias.float()
        return torch.nn.functional.gelu(pre).half(), pre.half()

    run_example(
        matmul_bias_gelu_aux,
        baseline_gelu_aux,  # pyrefly: ignore[bad-argument-type]
        (x, w, bias),
    )


# %%
# Main
# ----


# %%
def main() -> None:
    check(8192, 8192, 8192)


if __name__ == "__main__":
    main()