Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions docs/api/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,35 @@ Configs are typically discovered automatically through autotuning, but can also
``[load1, load2, ..., loadN, store1, store2, ..., storeM]``
```

### Epilogue Optimization

```{eval-rst}
.. autoattribute:: Config.epilogue_subtile

Split factor for the epilogue (pointwise ops + store) along a tile dimension.
Splits the store from ``[BLOCK_M, BLOCK_N]`` into
``SUBTILE_FACTOR × [BLOCK_M, BLOCK_N / SUBTILE_FACTOR]``, reducing the accumulator
shared-memory footprint and enabling extra pipeline stages.

**Valid values:**

- ``None``: Disabled (default)
- ``2``: Split epilogue into 2 sub-tiles
- ``4``: Split epilogue into 4 sub-tiles (when K ≥ 16384)

**Requirements:**

- Blackwell (sm_100+) GPU with tensor descriptor support
- Automatically discovered by the autotuner when the K dimension is ≥ 1024

**Interactions:**

- Incompatible with ``flatten_loops=True``
- Forces store indexing to ``"tensor_descriptor"``

See the :doc:`epilogue subtiling example </examples/epilogue_subtiling>` for usage patterns.
```

### Memory and Caching

```{eval-rst}
Expand Down
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ portable between different hardware. Helion automates and autotunes over:
* PID swizzling for improved L2 cache reuse.
* Loop reordering.
* Persistent kernel strategies.
* Epilogue subtiling for matmul-heavy kernels (Blackwell GPUs).
* Warp specialization choices, unrolling, and more.

## Try Helion Now
Expand Down
146 changes: 146 additions & 0 deletions examples/epilogue_subtiling.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need an example file here, shouldn't the autotuner find it automatically?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The examples are just to illustrate to users that epilogue subtiling is supported and demonstrate cases where it can provide benefits.

Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Epilogue Subtiling Example
==========================

This example demonstrates matmul kernels with heavy epilogues that benefit from
epilogue subtiling on Blackwell (sm_100+). Epilogue subtiling splits the store
from ``[BLOCK_M, BLOCK_N]`` into ``SUBTILE_FACTOR x [BLOCK_M, BLOCK_N / SUBTILE_FACTOR]``,
reducing the accumulator shared-memory footprint and enabling extra pipeline stages.
"""

# %%
# Imports
# -------

# %%
from __future__ import annotations

import torch

import helion
from helion._testing import DEVICE
from helion._testing import HALF_DTYPE
from helion._testing import run_example
import helion.language as hl

# %%
# Kernel 1 -- Matmul + Residual + Bias + GELU + Cast
# ---------------------------------------------------
# CUTLASS-style residual + bias + GELU forward epilogue with two
# fp32 reads (residual, bias) fused into the output tile.


# %%
@helion.kernel(static_shapes=True)
def matmul_bias_residual_gelu_cast(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor,
) -> torch.Tensor:
    """Matmul with a fused residual + bias + GELU epilogue, cast to fp16.

    Computes ``gelu(x @ w * 1.25 + residual * 0.5 + bias)`` with fp32
    accumulation and epilogue math, storing the result as fp16.

    Args:
        x: ``[m, k]`` left matmul operand.
        w: ``[k, n]`` right matmul operand.
        bias: ``[n]`` per-column bias added in the epilogue.
        residual: ``[m, n]`` residual blended into the output tile.

    Returns:
        ``[m, n]`` fp16 output tensor.
    """
    m, k = x.size()
    _, n = w.size()
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        # fp32 accumulator for the K reduction; epilogue math stays in fp32.
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], w[tile_k, tile_n])

        # Heavy epilogue: scale, residual blend, bias, GELU, then down-cast.
        # The 1.25 / 0.5 factors are arbitrary example scales, not semantics.
        val = acc * 1.25
        val = val + residual[tile_m, tile_n].to(torch.float32) * 0.5
        val = val + bias[tile_n]
        val = torch.nn.functional.gelu(val)
        out[tile_m, tile_n] = val.to(torch.float16)

    return out


# %%
# Kernel 2 -- Matmul + Bias + GELU with Auxiliary Output
# ------------------------------------------------------
# cuBLASLt / CUTLASS-style GELU+AUX forward epilogue that writes both
# the pre-activation (aux) and post-GELU (out) tensors.


# %%
@helion.kernel(static_shapes=True)
def matmul_bias_gelu_aux(
    x: torch.Tensor,
    w: torch.Tensor,
    bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Matmul + bias + GELU epilogue that also stores the pre-activation.

    Mirrors the cuBLASLt / CUTLASS "GELU + AUX" forward epilogue: the
    pre-activation (``aux``) and the post-GELU result (``out``) are both
    written from the same accumulator tile.

    Args:
        x: ``[m, k]`` left matmul operand.
        w: ``[k, n]`` right matmul operand.
        bias: ``[n]`` per-column bias added in the epilogue.

    Returns:
        ``(out, aux)`` — both ``[m, n]`` fp16; ``aux`` is the pre-GELU value.
    """
    m, k = x.size()
    _, n = w.size()
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    aux = torch.empty([m, n], dtype=torch.float16, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        # fp32 accumulator for the K reduction.
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], w[tile_k, tile_n])

        # Epilogue with two stores: pre-activation (aux) and post-GELU (out).
        # 1.25 is an arbitrary example scale.
        pre = acc * 1.25
        pre = pre + bias[tile_n]
        aux[tile_m, tile_n] = pre.to(torch.float16)
        out[tile_m, tile_n] = torch.nn.functional.gelu(pre).to(torch.float16)

    return out, aux


# %%
# Verification
# ------------


# %%
def check(m: int, k: int, n: int) -> None:
    """Validate both epilogue kernels against eager fp32 references.

    Builds fp16 inputs of the given matmul sizes, then runs each Helion
    kernel through ``run_example`` against a baseline that performs the
    same math in fp32 before casting back to fp16.
    """
    x = torch.randn([m, k], device=DEVICE, dtype=HALF_DTYPE)
    w = torch.randn([k, n], device=DEVICE, dtype=HALF_DTYPE)
    bias = torch.randn([n], device=DEVICE, dtype=HALF_DTYPE)
    residual = torch.randn([m, n], device=DEVICE, dtype=HALF_DTYPE)

    def eager_residual_gelu(
        x: torch.Tensor,
        w: torch.Tensor,
        bias: torch.Tensor,
        residual: torch.Tensor,
    ) -> torch.Tensor:
        # fp32 reference: scaled matmul, residual blend, bias, GELU, cast.
        pre = x.float() @ w.float()
        pre = pre * 1.25 + residual.float() * 0.5 + bias.float()
        return torch.nn.functional.gelu(pre).half()

    run_example(
        matmul_bias_residual_gelu_cast,
        eager_residual_gelu,
        (x, w, bias, residual),
    )

    def eager_gelu_aux(
        x: torch.Tensor,
        w: torch.Tensor,
        bias: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # fp32 reference returning (post-GELU, pre-activation) like the kernel.
        pre = x.float() @ w.float() * 1.25 + bias.float()
        return torch.nn.functional.gelu(pre).half(), pre.half()

    run_example(
        matmul_bias_gelu_aux,
        eager_gelu_aux,  # pyrefly: ignore[bad-argument-type]
        (x, w, bias),
    )


# %%
# Main
# ----


# %%
def main() -> None:
    """Run the correctness check at a large, epilogue-heavy problem size."""
    # 8192^3 matmul — large K so the autotuner can consider epilogue subtiling.
    check(8192, 8192, 8192)


if __name__ == "__main__":
    main()
48 changes: 34 additions & 14 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,27 +1765,35 @@ def flush_phases(self) -> None:
self.current_phase_roots = []


def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int, int]:
def _count_device_loads_and_stores(
device_ir: DeviceIR,
) -> tuple[int, int, list[int]]:
"""Count the number of load and store operations in device code for autotuning.

Returns:
tuple[int, int, int]: (total_load_count, loads_without_eviction_policy, store_count)
tuple[int, int, list[int]]: (
total_load_count,
loads_without_eviction_policy,
store_indices,
)
- total_load_count: all loads (for indexing tunable)
- loads_without_eviction_policy: loads that need eviction policy tuning
- store_count: all stores (for indexing tunable)
- store_indices: positions of store ops in the combined indexing list
"""
from ..language import memory_ops

total_load_count = 0
loads_without_eviction_policy = 0
store_count = 0
memory_op_index = 0
store_indices: list[int] = []

for graph_info in device_ir.graphs:
for node in graph_info.graph.nodes:
if node.op == "call_function":
# Check if this is a load operation
if node.target is memory_ops.load:
total_load_count += 1
memory_op_index += 1
# Check if this load needs eviction policy tuning
# (user can still specify eviction_policy to override tuning)
eviction_policy_arg = node.kwargs.get("eviction_policy")
Expand All @@ -1797,30 +1805,38 @@ def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int, int]:
loads_without_eviction_policy += 1
# Check if this is a store operation
elif node.target is memory_ops.store:
store_count += 1
store_indices.append(memory_op_index)
memory_op_index += 1

return total_load_count, loads_without_eviction_policy, store_count
return (
total_load_count,
loads_without_eviction_policy,
store_indices,
)


def _register_load_store_tunables(
total_load_count: int, loads_without_eviction_policy: int, store_count: int
total_load_count: int,
loads_without_eviction_policy: int,
store_indices: list[int],
) -> None:
"""Register list-based tunables (indexing, eviction policies) for all device loads and stores.

Args:
total_load_count: Total number of loads (for indexing tunable)
loads_without_eviction_policy: Number of loads that need eviction policy tuning
store_count: Total number of stores (for indexing tunable)
store_indices: Positions of store ops in the combined indexing list
"""
store_count = len(store_indices)
env = CompileEnvironment.current()
env.config_spec.store_indices = store_indices
if total_load_count == 0 and store_count == 0:
return

from ..autotuner.config_fragment import EnumFragment
from ..autotuner.config_fragment import ListOf
from ..autotuner.config_spec import get_valid_eviction_policies

env = CompileEnvironment.current()

# Register eviction policies only for loads without explicit eviction_policy
if loads_without_eviction_policy > 0:
env.config_spec.load_eviction_policies = ListOf(
Expand Down Expand Up @@ -1893,11 +1909,15 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
CompileEnvironment.current().config_spec.disallow_pid_type("xyz")

# Count all device loads and stores and register tunables
total_load_count, loads_without_eviction_policy, store_count = (
_count_device_loads_and_stores(device_ir)
)
(
total_load_count,
loads_without_eviction_policy,
store_indices,
) = _count_device_loads_and_stores(device_ir)
_register_load_store_tunables(
total_load_count, loads_without_eviction_policy, store_count
total_load_count,
loads_without_eviction_policy,
store_indices,
)

return device_ir
Expand Down
2 changes: 1 addition & 1 deletion helion/_compiler/tile_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def get_tl_range_kwargs(config: Config, block_idx: int) -> list[str]:
)
num_stages = config.num_stages

if config.indexing == "tensor_descriptor":
if "tensor_descriptor" in config.indexing:
# Tensor descriptor + multi-stage pipelines in addition to unrolling tend to cause
# CUDA "misaligned address" or "unspecified launch failure" errors.
if range_num_stages > 0:
Expand Down
44 changes: 27 additions & 17 deletions helion/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,14 @@ def output_only(
return result


def _as_tensors(result: object) -> list[torch.Tensor]:
"""Normalize a single tensor or tuple of tensors to a flat list."""
if isinstance(result, tuple):
return [t.clone() for t in result]
assert isinstance(result, torch.Tensor)
return [result.clone()]


def run_example(
kernel_fn: Callable[..., torch.Tensor] | Kernel | dict[str, Kernel],
baseline_fn: Callable[..., torch.Tensor] | dict[str, Callable[..., torch.Tensor]],
Expand Down Expand Up @@ -1039,29 +1047,31 @@ def run_example(

# Check correctness against first baseline
first_baseline_name, first_baseline_func = next(iter(baselines.items()))
expected = first_baseline_func(*args).clone()
expected = _as_tensors(first_baseline_func(*args))

for name, func in {**kernels, **baselines}.items():
if name != first_baseline_name:
print(f"Testing {name} correctness...", file=sys.stderr)
# Clone args to avoid buffer donation issues (e.g., Pallas/TPU)
cloned_args = _clone_args(args, process_group_name=process_group_name)
result = func(*cloned_args).clone()
if max_mismatch_pct is not None:
assert_close_with_mismatch_tolerance(
result.to(torch.float32),
expected.to(torch.float32),
atol=atol,
rtol=rtol,
max_mismatch_pct=max_mismatch_pct,
)
else:
torch.testing.assert_close(
result.to(torch.float32),
expected.to(torch.float32),
rtol=rtol,
atol=atol,
)
result = _as_tensors(func(*cloned_args))
assert len(result) == len(expected)
for r, e in zip(result, expected, strict=True):
if max_mismatch_pct is not None:
assert_close_with_mismatch_tolerance(
r.to(torch.float32),
e.to(torch.float32),
atol=atol,
rtol=rtol,
max_mismatch_pct=max_mismatch_pct,
)
else:
torch.testing.assert_close(
r.to(torch.float32),
e.to(torch.float32),
rtol=rtol,
atol=atol,
)

# Test backward pass
if bwd:
Expand Down
5 changes: 4 additions & 1 deletion helion/autotuner/config_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ def get_next_value(spec: ConfigSpecFragment) -> object:
advanced_controls_files=self._advanced_controls_files,
)
assert next(count) == len(flat_values)
return self._apply_overrides(config)
config = self._apply_overrides(config)
# Overrides may reintroduce pointer stores that break subtiled outputs
self.config_spec.fix_epilogue_subtile_store_indexing(config.config)
return config

def block_numel(self, flat_config: FlatConfig) -> int:
return functools.reduce(
Expand Down
Loading
Loading