Skip to content

Commit 0453c35

Browse files
committed
[Autotuner] Add crash recovery script for unrecoverable CUDA errors
stack-info: PR: #1923, branch: yf225/stack/93
1 parent 16d3fb7 commit 0453c35

File tree

6 files changed

+381
-3
lines changed

6 files changed

+381
-3
lines changed

docs/deployment_autotuning.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,10 @@ automatically. On successful completion, the checkpoint file is cleaned up.
196196

197197
```bash
198198
# Enable checkpointing to a directory:
199-
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
199+
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/$USER/helion_checkpoints python run_kernel.py
200200

201201
# If interrupted, just re-run with the same directory to resume:
202-
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
202+
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/$USER/helion_checkpoints python run_kernel.py
203203
```
204204

205205
Without `HELION_AUTOTUNE_CHECKPOINT_DIR`, no checkpoints are saved (opt-in).

helion/autotuner/base_search.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
351351
self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
352352
self._precompile_args_path: str | None = None
353353
self._precompile_result_counter = count()
354+
self._crashed_config_strs: set[str] = set()
354355

355356
def _prepare(self) -> None:
356357
"""Some initialization deferred until autotuning actually runs.
@@ -494,6 +495,32 @@ def _try_load_checkpoint(self) -> bool:
494495

495496
def _recompile_after_checkpoint(self) -> None:
496497
"""Recompile after loading a checkpoint. Override in subclasses."""
498+
499+
def _load_crashed_configs(self) -> None:
500+
"""Load crashed configs from {hash}.crashed_configs (written by crash-recovery script)."""
501+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
502+
if checkpoint_dir_str is None:
503+
return
504+
crashed_configs_path = (
505+
Path(checkpoint_dir_str) / f"{self._get_stable_hash()}.crashed_configs"
506+
)
507+
if crashed_configs_path.exists():
508+
self._crashed_config_strs |= {
509+
line.strip()
510+
for line in crashed_configs_path.read_text().splitlines()
511+
if line.strip()
512+
}
513+
if self._crashed_config_strs:
514+
self.log(
515+
f"Loaded {len(self._crashed_config_strs)} crashed config(s) to skip"
516+
)
517+
518+
def _get_pending_config_path(self) -> Path | None:
519+
"""Get path for pending-config sentinel, or None if checkpointing disabled."""
520+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
521+
if checkpoint_dir_str is None:
522+
return None
523+
return Path(checkpoint_dir_str) / f"{self._get_stable_hash()}.pending_config"
497524
def _compute_baseline(
498525
self,
499526
) -> tuple[object, Sequence[int], Sequence[object] | None]:
@@ -716,6 +743,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
716743
Returns:
717744
The performance of the configuration in ms.
718745
"""
746+
# Skip configs that previously crashed the subprocess
747+
config_str = str(config)
748+
if config_str in self._crashed_config_strs:
749+
self.log.warning(f"Skipping known-crashed config: {config}")
750+
return inf
751+
719752
self._autotune_metrics.num_configs_tested += 1
720753
self.counters["benchmark"] += 1
721754
self.log.debug(lambda: f"Running benchmark for {config!r}")
@@ -996,13 +1029,36 @@ def _benchmark(
9961029
A list of BenchmarkResult entries containing the configuration, compiled
9971030
callable, measured performance, status, and compilation time.
9981031
"""
1032+
# Filter out known-crashed configs before compilation
1033+
if self._crashed_config_strs:
1034+
original_len = len(configs)
1035+
configs = [c for c in configs if str(c) not in self._crashed_config_strs]
1036+
skipped = original_len - len(configs)
1037+
if skipped:
1038+
self.log.warning(
1039+
f"Skipped {skipped} known-crashed config(s) before compilation"
1040+
)
1041+
if not configs:
1042+
return []
1043+
9991044
fns: list[Callable[..., object]] = []
10001045
valid_configs: list[Config] = []
10011046
futures: list[PrecompileFuture] | None = None
1047+
pending_path = self._get_pending_config_path()
10021048
for i, config in enumerate(configs):
1049+
# Write sentinel before compile so a hard crash (SIGKILL /
1050+
# CUDA IMA) leaves a trace the crash recovery script can find.
1051+
if pending_path is not None:
1052+
pending_path.write_text(str(config))
10031053
try:
10041054
fn = self.kernel.compile_config(config, allow_print=False)
1005-
except Exception:
1055+
except Exception as e:
1056+
if match_unrecoverable_runtime_error(e):
1057+
# Leave sentinel for crash recovery — CUDA context is
1058+
# corrupted and the process cannot continue.
1059+
raise
1060+
if pending_path is not None:
1061+
pending_path.unlink(missing_ok=True)
10061062
# If all configs failed, raise error
10071063
if not valid_configs and i == len(configs) - 1:
10081064
raise
@@ -1012,9 +1068,14 @@ def _benchmark(
10121068
exc_info=True,
10131069
)
10141070
continue
1071+
if pending_path is not None:
1072+
pending_path.unlink(missing_ok=True)
10151073
fns.append(fn)
10161074
valid_configs.append(config)
10171075
configs = valid_configs
1076+
# NOTE: precompile runs in separate subprocesses with isolated CUDA
1077+
# contexts; crashes there are caught via is_working checks, not
1078+
# sentinels.
10181079
if self.settings.autotune_precompile:
10191080
futures = list(
10201081
starmap(
@@ -1076,7 +1137,14 @@ def _benchmark(
10761137
)
10771138
)
10781139
# benchmark one-by-one to avoid noisy results
1140+
# Write pending-config sentinel; cleared after benchmark.
1141+
# On crash the file stays so the crash recovery script can
1142+
# detect which config caused the failure.
1143+
if pending_path is not None:
1144+
pending_path.write_text(str(config))
10791145
perf = self.benchmark_function(config, fn)
1146+
if pending_path is not None:
1147+
pending_path.unlink(missing_ok=True)
10801148
status = "ok" if math.isfinite(perf) else "error"
10811149
# Log completion after benchmarking
10821150
self.log.record_autotune_entry(
@@ -1181,6 +1249,7 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11811249

11821250
if not self._try_load_checkpoint():
11831251
self._init_search()
1252+
self._load_crashed_configs()
11841253
try:
11851254
best = self._autotune()
11861255
self._cleanup_checkpoint()
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Autotuner crash recovery wrapper.
2+
3+
Runs a command (typically a Python script that calls helion autotuning) in a
4+
retry loop. When the process crashes due to an unrecoverable CUDA error
5+
(illegal memory access, misaligned address, etc.), the autotuner leaves a
6+
``{hash}.pending_config`` sentinel in the checkpoint directory. This script
7+
detects that file, records the poison config in ``{hash}.crashed_configs``, and
8+
re-runs the command. On re-run the autotuner loads its checkpoint and skips
9+
the crashed config.
10+
11+
Progress detection
12+
------------------
13+
Each crash should block a different config (since blocked configs are skipped
14+
on re-run). If the same config crashes twice, the autotuner is stuck and we
15+
give up.
16+
17+
Requirements
18+
------------
19+
``HELION_AUTOTUNE_CHECKPOINT_DIR`` must be set in the environment.
20+
21+
Usage
22+
-----
23+
::
24+
25+
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/$USER/helion_ckpt \\
26+
python -m helion.experimental.crash_recovery [--max-retries N] -- COMMAND [ARGS...]
27+
28+
Examples
29+
--------
30+
::
31+
32+
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/$USER/helion_autotune_ckpt \\
33+
python -m helion.experimental.crash_recovery -- python train.py
34+
"""
35+
36+
from __future__ import annotations
37+
38+
import argparse
39+
import os
40+
from pathlib import Path
41+
import subprocess
42+
import sys
43+
44+
45+
def _log(msg: str) -> None:
46+
print(f"[crash-recovery] {msg}", file=sys.stderr)
47+
48+
49+
def main(argv: list[str] | None = None) -> int:
50+
parser = argparse.ArgumentParser(
51+
description="Autotuner crash recovery wrapper.",
52+
usage=(
53+
"HELION_AUTOTUNE_CHECKPOINT_DIR=/path/to/dir\n"
54+
" %(prog)s [--max-retries N] -- COMMAND [ARGS...]"
55+
),
56+
)
57+
parser.add_argument(
58+
"--max-retries",
59+
type=int,
60+
default=50,
61+
help="Maximum number of crash recovery retries (default: 50)",
62+
)
63+
parser.add_argument(
64+
"command",
65+
nargs=argparse.REMAINDER,
66+
help="Command to run (after '--' separator)",
67+
)
68+
args = parser.parse_args(argv)
69+
70+
# argparse.REMAINDER absorbs '--' as first element when present.
71+
command: list[str] = args.command
72+
if command and command[0] == "--":
73+
command = command[1:]
74+
if not command:
75+
parser.error("no command specified after --")
76+
77+
checkpoint_dir_str = os.environ.get("HELION_AUTOTUNE_CHECKPOINT_DIR", "")
78+
if not checkpoint_dir_str:
79+
print(
80+
"Error: HELION_AUTOTUNE_CHECKPOINT_DIR must be set.",
81+
file=sys.stderr,
82+
)
83+
return 1
84+
85+
checkpoint_dir = Path(checkpoint_dir_str)
86+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
87+
88+
attempt = 0
89+
all_crashed: set[str] = set()
90+
91+
while True:
92+
attempt += 1
93+
94+
result = subprocess.run(command)
95+
exit_code = result.returncode
96+
97+
if exit_code == 0:
98+
return 0
99+
100+
# Look for any *.pending_config sentinel left by the autotuner.
101+
pending_files = sorted(checkpoint_dir.glob("*.pending_config"))
102+
103+
if pending_files:
104+
stuck = False
105+
for pending_path in pending_files:
106+
hash_prefix = pending_path.stem # {hash} without .pending_config
107+
crashed_configs_path = checkpoint_dir / f"{hash_prefix}.crashed_configs"
108+
109+
config = pending_path.read_text().strip()
110+
pending_path.unlink()
111+
112+
with open(crashed_configs_path, "a") as f:
113+
f.write(config + "\n")
114+
115+
_log(f"Blocked config: {config}")
116+
117+
# If this config was already blocked in a previous attempt,
118+
# the autotuner is not skipping it -- it's stuck.
119+
if config in all_crashed:
120+
stuck = True
121+
all_crashed.add(config)
122+
123+
_log(f"Process crashed (exit code {exit_code}, attempt {attempt}).")
124+
125+
if stuck:
126+
_log("Same config crashed twice \u2014 the autotuner appears stuck.")
127+
_log(
128+
"All crashed configs have been recorded. You can re-run "
129+
"this script and it will resume from the latest "
130+
"checkpoint, skipping all previously recorded crashed "
131+
"configs."
132+
)
133+
return 1
134+
135+
if attempt >= args.max_retries:
136+
_log(f"Reached maximum retry limit ({args.max_retries}). Giving up.")
137+
return 1
138+
139+
_log("Restarting from checkpoint...")
140+
else:
141+
# No pending file -- not a recoverable CUDA crash.
142+
return exit_code
143+
144+
145+
if __name__ == "__main__":
    # Script entry point (``python -m helion.experimental.crash_recovery``):
    # propagate main()'s return value as the process exit status.
    sys.exit(main())

test/data/autotune_crash_helper.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Helper script for crash recovery tests.
2+
3+
Run via:
4+
HELION_AUTOTUNE_CHECKPOINT_DIR=DIR \
5+
python -m helion.experimental.crash_recovery -- python test/data/autotune_crash_helper.py
6+
7+
On first run (when _CRASH_ON_FIRST_BENCHMARK or _CRASH_ON_FIRST_COMPILE is
8+
set and no counter file exists): patches do_bench / compile_config to trigger
9+
a hard crash, which exercises the pending_config sentinel and the crash
10+
recovery script. On subsequent runs: autotuning resumes from checkpoint
11+
normally, skipping the crashed config.
12+
13+
Without the crash env vars: runs autotuning normally (used to test that the
14+
crash recovery script passes through a successful run).
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import os
20+
from pathlib import Path
21+
22+
import torch
23+
24+
checkpoint_dir = os.environ["HELION_AUTOTUNE_CHECKPOINT_DIR"]
25+
crash_on_first_benchmark = os.environ.get("_CRASH_ON_FIRST_BENCHMARK", "")
26+
crash_on_first_compile = os.environ.get("_CRASH_ON_FIRST_COMPILE", "")
27+
counter_file = Path(checkpoint_dir) / "_crash_counter"
28+
29+
if crash_on_first_benchmark and not counter_file.exists():
30+
import triton
31+
import triton.language as tl
32+
33+
import helion.autotuner.base_search as _bs
34+
35+
@triton.jit
36+
def _ima_kernel(ptr):
37+
"""Triton kernel that triggers illegal memory access."""
38+
bad_ptr = ptr + (1 << 40)
39+
tl.store(bad_ptr, tl.full([], 42.0, dtype=tl.float32))
40+
41+
_original_do_bench = _bs.do_bench
42+
43+
def _ima_do_bench(*args, **kwargs): # type: ignore[no-untyped-def]
44+
counter_file.write_text("done")
45+
# Restore original so this only fires once
46+
_bs.do_bench = _original_do_bench
47+
# Trigger real CUDA illegal memory access
48+
x = torch.zeros(1, device="cuda")
49+
_ima_kernel[(1,)](x)
50+
torch.cuda.synchronize()
51+
# Should not reach here — IMA raises an exception
52+
return _original_do_bench(*args, **kwargs)
53+
54+
_bs.do_bench = _ima_do_bench
55+
56+
if crash_on_first_compile and not counter_file.exists():
57+
import triton
58+
import triton.language as tl
59+
60+
import helion.autotuner.base_search as _bs
61+
62+
@triton.jit
63+
def _ima_kernel_compile(ptr):
64+
"""Triton kernel that triggers illegal memory access."""
65+
bad_ptr = ptr + (1 << 40)
66+
tl.store(bad_ptr, tl.full([], 42.0, dtype=tl.float32))
67+
68+
# Wrap _benchmark so the real sentinel-writing code runs, but
69+
# compile_config triggers a real CUDA IMA on first call.
70+
# base_search._benchmark now detects unrecoverable errors and
71+
# preserves the sentinel instead of cleaning it up.
72+
_original_benchmark = _bs.BaseSearch._benchmark
73+
74+
def _crashing_benchmark(self, configs, **kwargs): # type: ignore[no-untyped-def]
75+
def _crash_compile(*args, **kw): # type: ignore[no-untyped-def]
76+
counter_file.write_text("done")
77+
# Trigger real CUDA illegal memory access during compile
78+
x = torch.zeros(1, device="cuda")
79+
_ima_kernel_compile[(1,)](x)
80+
torch.cuda.synchronize()
81+
82+
self.kernel.compile_config = _crash_compile
83+
return _original_benchmark(self, configs, **kwargs)
84+
85+
_bs.BaseSearch._benchmark = _crashing_benchmark # type: ignore[assignment]
86+
87+
# Import and run real autotuning
88+
from helion._testing import import_path # noqa: E402
89+
90+
datadir = Path(__file__).parent
91+
basic_kernels = import_path(datadir / "basic_kernels.py")
92+
93+
args = (torch.randn([8, 32], device="cuda"), torch.randn([8, 32], device="cuda"))
94+
bound = basic_kernels.add.bind(args)
95+
bound.settings.autotune_checkpoint_dir = checkpoint_dir
96+
bound.settings.autotune_effort = "quick"
97+
config = bound.autotune(args, force=True)
98+
result = bound(*args)
99+
torch.testing.assert_close(result, args[0] + args[1])

test/test_autotuner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def _init_search(self) -> None:
124124
search._precompile_result_counter = count()
125125
search._prepared = True
126126
search.counters = collections.Counter()
127+
search._crashed_config_strs = set()
127128
return search
128129

129130
def test_settings_flag_from_env(self):

0 commit comments

Comments
 (0)