Helion implementations of chunked linear attention covering the major linear attention / state-space model variants. Every example includes correctness tests (forward, backward, and recurrent step) against both a pure-PyTorch reference and flash-linear-attention (FLA) / mamba, plus benchmarks comparing Helion kernel performance against those libraries.
This work builds on ideas from the Attention Engine project.
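All variants instantiate the same chunked state-space recurrence. As a rough sketch in our own notation (not the code's), with state $S_t \in \mathbb{R}^{d_k \times d_v}$ and per-step decay $\alpha_t$:

$$
S_t = \operatorname{diag}(\alpha_t)\, S_{t-1} + k_t v_t^\top,
\qquad
o_t = S_t^\top q_t
$$

A scalar decay corresponds to $\alpha_t = g_t \mathbf{1}$, no decay to $\alpha_t = \mathbf{1}$, and the rank-1-correction (delta-rule) variants replace the plain write $k_t v_t^\top$ with $\beta_t\, k_t\, (v_t - S_{t-1}^\top k_t)^\top$.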
The examples require the following optional packages for reference comparisons and benchmarks:

```bash
# flash-linear-attention (FLA) — reference for most variants
pip install fla
# mamba — reference for the Mamba-2 SSD example only
pip install mamba_ssm
```

If these are not installed, the examples still run their core correctness tests against the pure-PyTorch reference, but the FLA/mamba comparisons and benchmarks are skipped with a warning.
| Example | Decay | Correction | FLA baseline |
|---|---|---|---|
| `example_simple_gla.py` | scalar | none | `chunk_simple_gla` |
| `example_full_gla.py` | diagonal | none | `chunk_gla` |
| `example_delta_rule.py` | none | rank-1 | `chunk_delta_rule` |
| `example_gated_delta_rule.py` | scalar | rank-1 | `chunk_gated_delta_rule` |
| `example_vanilla_linear_attn.py` | none | none | `chunk_linear_attn` |
| `example_retention.py` | scalar (fixed) | none | `chunk_retention` |
| `example_mamba2_ssd.py` | scalar | none | `mamba_chunk_scan_combined` |
| `example_rwkv6.py` | diagonal | none (+ output gate) | `chunk_rwkv6` |
| `example_kda.py` | diagonal | rank-1 | `chunk_kda` |
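To make the Decay and Correction columns concrete, here is a minimal pure-PyTorch sketch of one recurrent step of the gated delta rule (scalar decay plus rank-1 correction). Names and shapes are illustrative only, not the engine's actual API:

```python
import torch

def gated_delta_rule_step(S, q_t, k_t, v_t, g_t, beta_t):
    """One recurrent step: scalar decay g_t, then a rank-1 delta correction.

    S: (d_k, d_v) running state; q_t, k_t: (d_k,); v_t: (d_v,).
    """
    S = g_t * S                       # "Decay" column: scalar gate on the state
    pred = k_t @ S                    # what the decayed state currently stores for k_t
    S = S + beta_t * torch.outer(k_t, v_t - pred)  # "Correction" column: rank-1 delta write
    o_t = q_t @ S                     # readout for the query
    return S, o_t
```

With `g_t = 1`, `beta_t = 1`, and a unit-norm key, a single step stores `v_t` exactly: querying the new state with `k_t` returns `v_t`.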
- `linear_attention_engine.py` -- all Helion kernels (16 `@helion.experimental.aot_kernel()` functions, including the fused recurrent step), the `ChunkedLinearAttnFn` autograd wrapper, the `LinearAttentionEngine` class, and the `chunked_linear_attn()` / `recurrent_step()` public entry points.
- `linear_attention_utils.py` -- pure-PyTorch chunked reference implementation, naive recurrent reference, WY decomposition helpers, and input generators.
```bash
# Run all examples (test + benchmark):
python -m examples.linear.all

# Single example:
python -m examples.linear.example_simple_gla

# Run all tests via pytest:
pytest test/test_examples.py -k "test_linear" -v

# Run monkey-patch tests (plugs our engine into FLA layers):
pytest test/test_examples.py -k "monkeypatch" -v
```

The kernels use `@helion.experimental.aot_kernel()` and are ready for ahead-of-time autotuning.
Note: `HELION_AUTOTUNE_IGNORE_ERRORS=1` is required because certain autotuner-generated configs (specifically `flatten_loops=True` for kernels using `.sum()` reductions) trigger a Helion codegen bug. These configs are automatically skipped when ignore-errors is enabled.
```bash
# Quick tuning of all examples:
HELION_AUTOTUNE_PRECOMPILE=spawn \
HELION_AUTOTUNE_IGNORE_ERRORS=1 \
HELION_AUTOTUNE_MAX_GENERATIONS=5 \
python -m examples.linear.all

# Full AOT pipeline via the runner:
HELION_AUTOTUNE_PRECOMPILE=spawn \
HELION_AUTOTUNE_IGNORE_ERRORS=1 \
python -m helion.experimental.aot_runner --phase all \
  -- python -m examples.linear.all
```
- flash-linear-attention (FLA) by Songlin Yang, Yu Zhang, et al. -- the reference Triton implementations of GLA, DeltaNet, retention, and other linear attention variants that our examples test against.
- mamba by Albert Gu and Tri Dao -- the Mamba-2 SSD Triton kernel used as the baseline for the Mamba-2 example.
- Attention Engine -- a DSL-based approach to generating chunked linear attention kernels that inspired the generalized engine design.