Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/deployment_autotuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ automatically. On successful completion, the checkpoint file is cleaned up.

```bash
# Enable checkpointing to a directory:
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/$USER/helion_checkpoints python run_kernel.py

# If interrupted, just re-run with the same directory to resume:
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/$USER/helion_checkpoints python run_kernel.py
```

Without `HELION_AUTOTUNE_CHECKPOINT_DIR`, no checkpoints are saved (opt-in).
Expand Down
45 changes: 44 additions & 1 deletion helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
self._precompile_args_path: str | None = None
self._precompile_result_counter = count()
self._crashed_config_strs: set[str] = set()

def _prepare(self) -> None:
"""Some initialization deferred until autotuning actually runs.
Expand Down Expand Up @@ -739,6 +740,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
Returns:
The performance of the configuration in ms.
"""
# Skip configs that previously crashed the subprocess
config_str = str(config)
if config_str in self._crashed_config_strs:
self.log.warning(f"Skipping known-crashed config: {config}")
return inf

self._autotune_metrics.num_configs_tested += 1
self.log.debug(lambda: f"Running benchmark for {config!r}")
_captured_output: list[str] = [""]
Expand Down Expand Up @@ -1018,13 +1025,36 @@ def _benchmark(
A list of BenchmarkResult entries containing the configuration, compiled
callable, measured performance, status, and compilation time.
"""
# Filter out known-crashed configs before compilation
if self._crashed_config_strs:
original_len = len(configs)
configs = [c for c in configs if str(c) not in self._crashed_config_strs]
skipped = original_len - len(configs)
if skipped:
self.log.warning(
f"Skipped {skipped} known-crashed config(s) before compilation"
)
if not configs:
return []

fns: list[Callable[..., object]] = []
valid_configs: list[Config] = []
futures: list[PrecompileFuture] | None = None
pending_path = self._get_pending_config_path()
for i, config in enumerate(configs):
# Write sentinel before compile so a hard crash (SIGKILL /
# CUDA IMA) leaves a trace the crash recovery script can find.
if pending_path is not None:
pending_path.write_text(str(config))
try:
fn = self.kernel.compile_config(config, allow_print=False)
except Exception:
except Exception as e:
if match_unrecoverable_runtime_error(e):
# Leave sentinel for crash recovery — CUDA context is
# corrupted and the process cannot continue.
raise
if pending_path is not None:
pending_path.unlink(missing_ok=True)
# If all configs failed, raise error
if not valid_configs and i == len(configs) - 1:
raise
Expand All @@ -1034,9 +1064,14 @@ def _benchmark(
exc_info=True,
)
continue
if pending_path is not None:
pending_path.unlink(missing_ok=True)
fns.append(fn)
valid_configs.append(config)
configs = valid_configs
# NOTE: precompile runs in separate subprocesses with isolated CUDA
# contexts; crashes there are caught via is_working checks, not
# sentinels.
if self.settings.autotune_precompile:
futures = list(
starmap(
Expand Down Expand Up @@ -1098,7 +1133,14 @@ def _benchmark(
)
)
# benchmark one-by-one to avoid noisy results
# Write pending-config sentinel; cleared after benchmark.
# On crash the file stays so the crash recovery script can
# detect which config caused the failure.
if pending_path is not None:
pending_path.write_text(str(config))
perf = self.benchmark_function(config, fn)
if pending_path is not None:
pending_path.unlink(missing_ok=True)
status = "ok" if math.isfinite(perf) else "error"
# Log completion after benchmarking
self.log.record_autotune_entry(
Expand Down Expand Up @@ -1204,6 +1246,7 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
checkpoint_enabled = self.settings.autotune_checkpoint_dir is not None
if not (checkpoint_enabled and self._try_load_checkpoint()):
self._init_search()
self._load_crashed_configs()
try:
best = self._autotune()
if checkpoint_enabled:
Expand Down
146 changes: 146 additions & 0 deletions helion/experimental/crash_recovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""Autotuner crash recovery wrapper.

Runs a command (typically a Python script that calls helion autotuning) in a
retry loop. When the process crashes due to an unrecoverable CUDA error
(illegal memory access, misaligned address, etc.), the autotuner leaves a
``{hash}.pending_config`` sentinel in the checkpoint directory. This script
detects that file, records the poison config in ``{hash}.crashed_configs``, and
re-runs the command. On re-run the autotuner loads its checkpoint and skips
the crashed config.

Progress detection
------------------
Each crash should block a different config (since blocked configs are skipped
on re-run). If the same config crashes twice, the autotuner is stuck and we
give up.

Requirements
------------
``HELION_AUTOTUNE_CHECKPOINT_DIR`` must be set in the environment.

Usage
-----
::

HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/$USER/helion_ckpt \\
python -m helion.experimental.crash_recovery [--max-retries N] -- COMMAND [ARGS...]

Examples
--------
::

HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/$USER/helion_autotune_ckpt \\
python -m helion.experimental.crash_recovery -- python train.py
"""

from __future__ import annotations

import argparse
import os
from pathlib import Path
import subprocess
import sys


def _log(msg: str) -> None:
print(f"[crash-recovery] {msg}", file=sys.stderr)


def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(
description="Autotuner crash recovery wrapper.",
usage=(
"HELION_AUTOTUNE_CHECKPOINT_DIR=/path/to/dir\n"
" %(prog)s [--max-retries N] -- COMMAND [ARGS...]"
),
)
parser.add_argument(
"--max-retries",
type=int,
default=50,
help="Maximum number of crash recovery retries (default: 50)",
)
parser.add_argument(
"command",
nargs=argparse.REMAINDER,
help="Command to run (after '--' separator)",
)
args = parser.parse_args(argv)

# argparse.REMAINDER absorbs '--' as first element when present.
command: list[str] = args.command
if command and command[0] == "--":
command = command[1:]
if not command:
parser.error("no command specified after --")

checkpoint_dir_str = os.environ.get("HELION_AUTOTUNE_CHECKPOINT_DIR", "")
if not checkpoint_dir_str:
print(
"Error: HELION_AUTOTUNE_CHECKPOINT_DIR must be set.",
file=sys.stderr,
)
return 1

checkpoint_dir = Path(checkpoint_dir_str)
checkpoint_dir.mkdir(parents=True, exist_ok=True)

attempt = 0
all_crashed: set[str] = set()

while True:
attempt += 1

result = subprocess.run(command)
exit_code = result.returncode

if exit_code == 0:
return 0

# Look for any *.pending_config sentinel left by the autotuner.
pending_files = sorted(checkpoint_dir.glob("*.pending_config"))

if pending_files:
stuck = False
for pending_path in pending_files:
hash_prefix = pending_path.stem # {hash} without .pending_config
crashed_configs_path = checkpoint_dir / f"{hash_prefix}.crashed_configs"

config = pending_path.read_text().strip()
pending_path.unlink()

with open(crashed_configs_path, "a") as f:
f.write(config + "\n")

_log(f"Blocked config: {config}")

# If this config was already blocked in a previous attempt,
# the autotuner is not skipping it -- it's stuck.
if config in all_crashed:
stuck = True
all_crashed.add(config)

_log(f"Process crashed (exit code {exit_code}, attempt {attempt}).")

if stuck:
_log("Same config crashed twice \u2014 the autotuner appears stuck.")
_log(
"All crashed configs have been recorded. You can re-run "
"this script and it will resume from the latest "
"checkpoint, skipping all previously recorded crashed "
"configs."
)
return 1

if attempt >= args.max_retries:
_log(f"Reached maximum retry limit ({args.max_retries}). Giving up.")
return 1

_log("Restarting from checkpoint...")
else:
# No pending file -- not a recoverable CUDA crash.
return exit_code


if __name__ == "__main__":
    # Propagate the wrapped command's exit status to the invoking shell.
    sys.exit(main())
101 changes: 101 additions & 0 deletions test/data/autotune_crash_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Helper script for crash recovery tests.

Run via:
HELION_AUTOTUNE_CHECKPOINT_DIR=DIR \
python -m helion.experimental.crash_recovery -- python test/data/autotune_crash_helper.py

On first run (when _CRASH_ON_FIRST_BENCHMARK or _CRASH_ON_FIRST_COMPILE is
set and no counter file exists): patches do_bench / compile_config to trigger
a hard crash, which exercises the pending_config sentinel and the crash
recovery script. On subsequent runs: autotuning resumes from checkpoint
normally, skipping the crashed config.

Without the crash env vars: runs autotuning normally (used to test that the
crash recovery script passes through a successful run).
"""

from __future__ import annotations

import os
from pathlib import Path

import torch

from helion._testing import DEVICE

# Configuration comes from the environment. The checkpoint directory is
# required (KeyError if unset); the crash toggles are optional test knobs.
checkpoint_dir = os.environ["HELION_AUTOTUNE_CHECKPOINT_DIR"]
crash_on_first_benchmark = os.environ.get("_CRASH_ON_FIRST_BENCHMARK", "")
crash_on_first_compile = os.environ.get("_CRASH_ON_FIRST_COMPILE", "")
# Marker file recording that the one-shot crash already fired; its presence
# on a re-run disables the crash patches below so autotuning can finish.
counter_file = Path(checkpoint_dir) / "_crash_counter"

if crash_on_first_benchmark and not counter_file.exists():
    # First run only: patch base_search.do_bench so the very first benchmark
    # call triggers a hard CUDA crash during the benchmarking phase.
    import triton
    import triton.language as tl

    import helion.autotuner.base_search as _bs

    @triton.jit
    def _ima_kernel(ptr):
        """Triton kernel that triggers illegal memory access."""
        # Offset the pointer ~1 TiB past the allocation so the store is
        # guaranteed out of bounds.
        bad_ptr = ptr + (1 << 40)
        tl.store(bad_ptr, tl.full([], 42.0, dtype=tl.float32))

    _original_do_bench = _bs.do_bench

    def _ima_do_bench(*args, **kwargs):  # type: ignore[no-untyped-def]
        # Record that the crash fired so re-runs skip this patch.
        counter_file.write_text("done")
        # Restore original so this only fires once
        _bs.do_bench = _original_do_bench
        # Trigger real CUDA illegal memory access
        x = torch.zeros(1, device=DEVICE)
        _ima_kernel[(1,)](x)
        torch.cuda.synchronize()
        # Should not reach here — IMA raises an exception
        return _original_do_bench(*args, **kwargs)

    _bs.do_bench = _ima_do_bench

if crash_on_first_compile and not counter_file.exists():
    # First run only: crash during the compilation phase instead of the
    # benchmarking phase, to exercise the compile-time sentinel path.
    import triton
    import triton.language as tl

    import helion.autotuner.base_search as _bs

    @triton.jit
    def _ima_kernel_compile(ptr):
        """Triton kernel that triggers illegal memory access."""
        # Offset the pointer ~1 TiB past the allocation so the store is
        # guaranteed out of bounds.
        bad_ptr = ptr + (1 << 40)
        tl.store(bad_ptr, tl.full([], 42.0, dtype=tl.float32))

    # Wrap _benchmark so the real sentinel-writing code runs, but
    # compile_config triggers a real CUDA IMA on first call.
    # base_search._benchmark now detects unrecoverable errors and
    # preserves the sentinel instead of cleaning it up.
    _original_benchmark = _bs.BaseSearch._benchmark

    def _crashing_benchmark(self, configs, **kwargs):  # type: ignore[no-untyped-def]
        def _crash_compile(*args, **kw):  # type: ignore[no-untyped-def]
            # Record that the crash fired so re-runs skip this patch.
            counter_file.write_text("done")
            # Trigger real CUDA illegal memory access during compile
            x = torch.zeros(1, device=DEVICE)
            _ima_kernel_compile[(1,)](x)
            torch.cuda.synchronize()

        # Replace compile_config on this search's kernel, then delegate to
        # the unpatched _benchmark so its sentinel bookkeeping runs for real.
        self.kernel.compile_config = _crash_compile
        return _original_benchmark(self, configs, **kwargs)

    _bs.BaseSearch._benchmark = _crashing_benchmark  # type: ignore[assignment]

# Import and run real autotuning
from helion._testing import import_path  # noqa: E402

# Load the kernel under test from the sibling data file.
datadir = Path(__file__).parent
basic_kernels = import_path(datadir / "basic_kernels.py")

args = (torch.randn([8, 32], device=DEVICE), torch.randn([8, 32], device=DEVICE))
bound = basic_kernels.add.bind(args)
# Point checkpointing at the wrapper's directory so crash sentinels land
# where the crash-recovery script looks for them.
bound.settings.autotune_checkpoint_dir = checkpoint_dir
bound.settings.autotune_effort = "quick"
# force=True bypasses any cached config so autotuning actually runs.
config = bound.autotune(args, force=True)
# Sanity-check that the tuned kernel still computes a correct elementwise add.
result = bound(*args)
torch.testing.assert_close(result, args[0] + args[1])
1 change: 1 addition & 0 deletions test/test_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def _make_search(
search._precompile_args_path = None
search._precompile_result_counter = count()
search._prepared = True
search._crashed_config_strs = set()
return search

def test_settings_flag_from_env(self):
Expand Down
Loading
Loading