New file: `examples/linear/GRID_BUG_INVESTIGATION.md`

# Investigation: IMA with `hl.grid(N)` and dynamic N

## Summary

`chunk_bwd_dh_diag_fused` and similar kernels using `hl.grid(N)` where N is
a dynamic (non-specialized) value produce an illegal memory access (IMA) when:

1. The kernel is first compiled for one value of C (e.g. C=32, giving N=T/C=4)
2. Then called with a different C (e.g. C=64, giving N=32)
3. The second call is exercised rapidly via `do_bench`

The crash does **not** reproduce:
- With `CUDA_LAUNCH_BLOCKING=1` (serialized kernel launches)
- When the kernel is called in isolation (even 100x rapidly)
- When only one value of C is used throughout the process
- In a minimal standalone repro of the same kernel pattern

The crash **does** reproduce consistently when the full backward pipeline
runs after a test phase that compiled the kernel with a different C.

## Workaround

Specializing N with `hl.specialize()` fixes the IMA:

```python
# Before (crashes):
N = q_scaled.size(1)

# After (works):
N = hl.specialize(q_scaled.size(1))
```

This causes a recompilation per distinct N value, which is acceptable since
N = T/C typically takes a small number of values per model configuration.
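As a rough mental model (the names and cache structure below are illustrative, not Helion's actual internals), specializing a value folds it into the compile-cache key, so each distinct N gets its own compiled variant instead of one variant receiving N as a runtime argument:

```python
# Hypothetical sketch of specialization-keyed caching; get_kernel and
# compile_cache are illustrative names, not Helion's real implementation.
compile_cache: dict[int, str] = {}

def get_kernel(n: int) -> str:
    # The first call for a given specialized N compiles a variant;
    # later calls with the same N hit the cache.
    if n not in compile_cache:
        compile_cache[n] = f"kernel_variant_N{n}"
    return compile_cache[n]
```

With unspecialized N there is instead a single compiled variant that takes N as a runtime argument, which is the configuration that triggers the IMA above.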

## Generated code analysis

The generated Triton kernel receives N as a dynamic argument and uses it in:

```python
for offset_3 in tl.range(0, tl.cast(N, tl.int32), ...):
sub_1 = -1 + N + -1 * offset_3 # i.e. N - 1 - i_t
tl.store(dh_all + sub_1 * dh_all_stride_1 + ..., ...)
```

The loop bounds and indexing look correct. However, the kernel also has
`_RDIM_SIZE_2: tl.constexpr` parameters that are set to
`triton.next_power_of_2(dh_init.size(1))` — these control the size of
`tl.arange` used for the D dimension.

When the kernel is compiled for shape1 (D=32 → `_RDIM_SIZE_2=32`) and then
a new compilation happens for shape2 (D=128 → `_RDIM_SIZE_2=128`), both
compilations share the same source but differ in constexpr values. The
suspicion is that something in the Triton compilation or caching interacts
badly when `hl.grid(N)` has a dynamic range **and** the kernel has multiple
constexpr-specialized variants coexisting.
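The index arithmetic behind that suspicion can be sketched in plain NumPy (this reproduces only the rounding and masking, not the Triton kernel itself):

```python
import numpy as np

def next_power_of_2(n: int) -> int:
    # The same rounding triton.next_power_of_2 performs.
    return 1 << (n - 1).bit_length()

# Each compiled variant bakes in rdim = next_power_of_2(D) as a constexpr
# and masks tl.arange(0, rdim) with idx < D.
D = 32
rdim = next_power_of_2(D)          # 32 for this variant
idx = np.arange(rdim)
assert idx[idx < D].max() == D - 1  # masked indices stay inside the buffer

# If the D=128 variant (rdim=128) were ever dispatched against a D=32
# buffer, its unmasked indices would run to 127 against 32 valid rows:
# an out-of-bounds access of exactly the kind an IMA reports.
assert np.arange(next_power_of_2(128)).max() >= D
```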

## What's NOT the cause

- **tensor_descriptor indexing**: crash reproduces with all-pointer indexing
- **persistent_blocked pid_type**: crash reproduces with flat pid_type
- **block_sizes**: crash reproduces with every tested block_size value
- **Any single config parameter**: every parameter variant crashes; only
the presence/absence of a prior compilation with different C matters

## Affected kernels

All four kernels using `hl.grid(N)` with unspecialized N:

- `chunk_fwd_h_diag_fused` (line 336)
- `chunk_fwd_phase1_diag_fused` (line 409)
- `chunk_bwd_dh_diag_fused` (line 490)
- `chunk_bwd_dh_correction_diag_fused` (line 524)

## Likely root cause

The interaction between:
1. `hl.grid(N)` producing `tl.range(0, N)` with dynamic N
2. Multiple constexpr-specialized compilations of the same kernel source
3. Rapid repeated execution via `do_bench` without full synchronization

may cause Triton to reuse a compiled kernel variant with incorrect grid
bounds or rdim sizes. This needs investigation at the Helion/Triton level,
potentially with `compute-sanitizer --tool memcheck` to identify the exact
out-of-bounds access.
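A starting point for that investigation might look like the following; the target module is an assumption, so substitute whichever entry point actually reproduces the crash:

```bash
# compute-sanitizer reports the kernel name, thread, and address of the
# first out-of-bounds operation it observes.
compute-sanitizer --tool memcheck python -m examples.linear.all
```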
New file: `examples/linear/README.md`

# Linear Attention Examples

Helion implementations of chunked linear attention covering the major linear
attention / state-space model variants. Every example includes correctness
tests (forward, backward, and recurrent step) against both a pure-PyTorch
reference **and** [flash-linear-attention](https://github.com/fla-org/flash-linear-attention)
(FLA) / [mamba](https://github.com/state-spaces/mamba), plus benchmarks
comparing Helion kernel performance against those libraries.

This work builds on ideas from the
[Attention Engine](https://github.com/fla-org/attention-engine) project.

## Dependencies

The examples require the following optional packages for reference comparisons
and benchmarks:

```bash
# flash-linear-attention (FLA) — reference for most variants
pip install fla

# mamba — reference for Mamba-2 SSD only
pip install mamba_ssm
```

If these are not installed, the examples will still run their core correctness
tests against the pure-PyTorch reference, but FLA/mamba comparisons and
benchmarks will be skipped with a warning.

## Variants

| Example | Decay | Correction | FLA baseline |
|---------|-------|------------|--------------|
| `example_simple_gla.py` | scalar | none | `chunk_simple_gla` |
| `example_full_gla.py` | diagonal | none | `chunk_gla` |
| `example_delta_rule.py` | none | rank-1 | `chunk_delta_rule` |
| `example_gated_delta_rule.py` | scalar | rank-1 | `chunk_gated_delta_rule` |
| `example_vanilla_linear_attn.py` | none | none | `chunk_linear_attn` |
| `example_retention.py` | scalar (fixed) | none | `chunk_retention` |
| `example_mamba2_ssd.py` | scalar | none | `mamba_chunk_scan_combined` |
| `example_rwkv6.py` | diagonal | none (+ output gate) | `chunk_rwkv6` |
| `example_kda.py` | diagonal | rank-1 | `chunk_kda` |
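The "Decay" column describes how each step discounts the running state. As a minimal pure-NumPy sketch (illustrative only, not the chunked Helion kernels), scalar decay multiplies the whole state matrix by a per-step gate, and vanilla linear attention is the special case where the gate is always 1:

```python
import numpy as np

def recurrent_scalar_decay(q, k, v, g):
    """Naive recurrence: S_t = g_t * S_{t-1} + k_t^T v_t, then o_t = q_t S_t.

    q, k: (T, Dk); v: (T, Dv); g: (T,) scalar decay per step.
    """
    S = np.zeros((q.shape[1], v.shape[1]))
    o = np.empty((q.shape[0], v.shape[1]))
    for t in range(q.shape[0]):
        S = g[t] * S + np.outer(k[t], v[t])  # decay old state, add new pair
        o[t] = q[t] @ S
    return o
```

Diagonal decay replaces the scalar `g[t]` with a per-feature vector, and the rank-1 correction of the delta-rule variants modifies the update term rather than the decay.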

## Architecture

**`linear_attention_engine.py`** -- All Helion kernels (16 `@helion.experimental.aot_kernel()`
functions, including the fused recurrent step), the `ChunkedLinearAttnFn` autograd wrapper,
`LinearAttentionEngine` class, and the `chunked_linear_attn()` / `recurrent_step()`
public entry points.

**`linear_attention_utils.py`** -- Pure-PyTorch chunked reference implementation,
naive recurrent reference, WY decomposition helpers, input generators.
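The structure of the chunked reference can be illustrated for the vanilla (no-decay) case; this is a simplified sketch of the idea, not the actual helper code:

```python
import numpy as np

def chunked_vanilla_linear_attn(q, k, v, C):
    """Split the sequence into chunks of size C: each chunk's output is an
    inter-chunk term (all earlier chunks, via the running state S) plus an
    intra-chunk term (causal attention within the chunk)."""
    S = np.zeros((q.shape[1], v.shape[1]))
    o = np.empty((q.shape[0], v.shape[1]))
    for s0 in range(0, q.shape[0], C):
        qc, kc, vc = q[s0:s0 + C], k[s0:s0 + C], v[s0:s0 + C]
        intra = np.tril(qc @ kc.T) @ vc   # causal attention inside the chunk
        o[s0:s0 + C] = qc @ S + intra     # plus contribution of past chunks
        S += kc.T @ vc                    # fold this chunk into the state
    return o
```

For any chunk size this matches the quadratic causal form `np.tril(q @ k.T) @ v`, which is the invariant the correctness tests rely on.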

## Running

```bash
# Run all examples (test + benchmark):
python -m examples.linear.all

# Single example:
python -m examples.linear.example_simple_gla

# Run all tests via pytest:
pytest test/test_examples.py -k "test_linear" -v

# Run monkey-patch tests (plugs our engine into FLA layers):
pytest test/test_examples.py -k "monkeypatch" -v
```

## AOT autotuning

The kernels use `@helion.experimental.aot_kernel()` and are ready for
ahead-of-time autotuning.

**Note:** `HELION_AUTOTUNE_IGNORE_ERRORS=1` is required because certain
autotuner-generated configs (specifically `flatten_loops=True` for kernels
using `.sum()` reductions) trigger a Helion codegen bug. These configs are
automatically skipped when ignore-errors is enabled.

```bash
# Quick tuning of all examples:
HELION_AUTOTUNE_PRECOMPILE=spawn \
HELION_AUTOTUNE_IGNORE_ERRORS=1 \
HELION_AUTOTUNE_MAX_GENERATIONS=5 \
python -m examples.linear.all

# Full AOT pipeline via the runner:
HELION_AUTOTUNE_PRECOMPILE=spawn \
HELION_AUTOTUNE_IGNORE_ERRORS=1 \
python -m helion.experimental.aot_runner --phase all \
  -- python -m examples.linear.all
```

## Acknowledgements

- [flash-linear-attention](https://github.com/fla-org/flash-linear-attention)
(FLA) by Songlin Yang, Yu Zhang *et al.* -- the reference Triton
implementations of GLA, DeltaNet, retention, and other linear attention
variants that our examples test against.

- [mamba](https://github.com/state-spaces/mamba) by Albert Gu and
Tri Dao -- the Mamba-2 SSD Triton kernel used as baseline for the
Mamba-2 example.

- [Attention Engine](https://github.com/fla-org/attention-engine) --
a DSL-based approach to generating chunked linear attention kernels
that inspired the generalized engine design.
New empty file: `examples/linear/__init__.py`

New file: `examples/linear/all.py`:

```python
"""
Run all linear attention examples (test + benchmark).

Usage::

    python -m examples.linear.all
    HELION_USE_DEFAULT_CONFIG=1 python -m examples.linear.all
"""

from __future__ import annotations

import importlib
import sys
import traceback

EXAMPLES = [
    "example_simple_gla",
    "example_full_gla",
    "example_delta_rule",
    "example_gated_delta_rule",
    "example_vanilla_linear_attn",
    "example_retention",
    "example_mamba2_ssd",
    "example_rwkv6",
    "example_kda",
]


def main() -> None:
    results: list[tuple[str, str]] = []

    for name in EXAMPLES:
        print(f"\n{'=' * 70}")
        print(f" {name}")
        print(f"{'=' * 70}")
        try:
            mod = importlib.import_module(f"examples.linear.{name}")
            mod.main()
            results.append((name, "OK"))
        except Exception:
            traceback.print_exc()
            results.append((name, "FAIL"))

    print(f"\n{'=' * 70}")
    print(" Summary")
    print(f"{'=' * 70}")
    for name, status in results:
        print(f" {name:<40} {status}")

    ok = sum(1 for _, s in results if s == "OK")
    print(f"\n{ok}/{len(results)} passed")

    if ok < len(results):
        sys.exit(1)


if __name__ == "__main__":
    main()
```