Skip to content

Commit 5490390

Browse files
committed
[Autotuner] Add crash recovery script for unrecoverable CUDA errors
stack-info: PR: #1923, branch: yf225/stack/93
1 parent 4872e5d commit 5490390

File tree

4 files changed

+348
-0
lines changed

4 files changed

+348
-0
lines changed

helion/autotuner/base_search.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
435435
self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
436436
self._precompile_args_path: str | None = None
437437
self._precompile_result_counter = count()
438+
self._crashed_config_strs: set[str] = set()
438439

439440
def _prepare(self) -> None:
440441
"""Some initialization deferred until autotuning actually runs.
@@ -534,6 +535,26 @@ def _try_load_checkpoint(self) -> bool:
534535
self.log(f"Resumed at generation {self._current_generation}")
535536
return True
536537

538+
def _load_crashed_configs(self) -> None:
539+
"""Load crashed configs from {hash}.crashed_configs (written by crash-recovery script)."""
540+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
541+
if checkpoint_dir_str is None:
542+
return
543+
crashed_configs_path = (
544+
Path(checkpoint_dir_str)
545+
/ f"{self._get_stable_hash()}.crashed_configs"
546+
)
547+
if crashed_configs_path.exists():
548+
self._crashed_config_strs |= {
549+
line.strip()
550+
for line in crashed_configs_path.read_text().splitlines()
551+
if line.strip()
552+
}
553+
if self._crashed_config_strs:
554+
self.log(
555+
f"Loaded {len(self._crashed_config_strs)} crashed config(s) to skip"
556+
)
557+
537558
def _compute_baseline(
538559
self,
539560
) -> tuple[object, Sequence[int], Sequence[object] | None]:
@@ -752,6 +773,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
752773
Returns:
753774
The performance of the configuration in ms.
754775
"""
776+
# Skip configs that previously crashed the subprocess
777+
config_str = str(config)
778+
if config_str in self._crashed_config_strs:
779+
self.log.warning(f"Skipping known-crashed config: {config}")
780+
return inf
781+
755782
self._autotune_metrics.num_configs_tested += 1
756783
self.counters["benchmark"] += 1
757784
self.log.debug(lambda: f"Running benchmark for {config!r}")
@@ -1016,10 +1043,23 @@ def _benchmark(
10161043
fns: list[Callable[..., object]] = []
10171044
valid_configs: list[Config] = []
10181045
futures: list[PrecompileFuture] | None = None
1046+
# Compute pending config path once for breadcrumb writes.
1047+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
1048+
pending_path = (
1049+
Path(checkpoint_dir_str) / f"{self._get_stable_hash()}.pending_config"
1050+
if checkpoint_dir_str is not None
1051+
else None
1052+
)
10191053
for i, config in enumerate(configs):
1054+
# Write breadcrumb before compile so a hard crash (SIGKILL /
1055+
# CUDA IMA) leaves a trace the crash recovery script can find.
1056+
if pending_path is not None:
1057+
pending_path.write_text(str(config))
10201058
try:
10211059
fn = self.kernel.compile_config(config, allow_print=False)
10221060
except Exception:
1061+
if pending_path is not None:
1062+
pending_path.unlink(missing_ok=True)
10231063
# If all configs failed, raise error
10241064
if not valid_configs and i == len(configs) - 1:
10251065
raise
@@ -1029,6 +1069,8 @@ def _benchmark(
10291069
exc_info=True,
10301070
)
10311071
continue
1072+
if pending_path is not None:
1073+
pending_path.unlink(missing_ok=True)
10321074
fns.append(fn)
10331075
valid_configs.append(config)
10341076
configs = valid_configs
@@ -1089,7 +1131,14 @@ def _benchmark(
10891131
)
10901132
)
10911133
# benchmark one-by-one to avoid noisy results
1134+
# Write pending-config breadcrumb; cleared after benchmark.
1135+
# On crash the file stays so the crash recovery script can
1136+
# detect which config caused the failure.
1137+
if pending_path is not None:
1138+
pending_path.write_text(str(config))
10921139
perf = self.benchmark_function(config, fn)
1140+
if pending_path is not None:
1141+
pending_path.unlink(missing_ok=True)
10931142
status = "ok" if math.isfinite(perf) else "error"
10941143
# Log completion after benchmarking
10951144
self.log.record_autotune_entry(
@@ -1194,6 +1243,7 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11941243

11951244
if not self._try_load_checkpoint():
11961245
self._init_search()
1246+
self._load_crashed_configs()
11971247
try:
11981248
best = self._autotune()
11991249
self._cleanup_checkpoint()
@@ -1296,6 +1346,12 @@ def _cleanup_checkpoint(self) -> None:
12961346
checkpoint_file.unlink()
12971347
self.log(f"Checkpoint cleaned up: {checkpoint_file}")
12981348

1349+
# Clean up crash-recovery artifacts
1350+
for suffix in (".pending_config", ".crashed_configs"):
1351+
artifact = Path(checkpoint_dir_str) / f"{stable_hash}{suffix}"
1352+
if artifact.exists():
1353+
artifact.unlink()
1354+
12991355
@staticmethod
13001356
def _serialize_numpy_rng_state(
13011357
state: tuple[str, Any, int, int, float],
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""Autotuner crash recovery wrapper.
2+
3+
Runs a command (typically a Python script that calls helion autotuning) in a
4+
retry loop. When the process crashes due to an unrecoverable CUDA error
5+
(illegal memory access, misaligned address, etc.), the autotuner leaves a
6+
``{hash}.pending_config`` breadcrumb in the checkpoint directory. This script
7+
detects that file, records the poison config in ``{hash}.crashed_configs``, and
8+
re-runs the command. On re-run the autotuner loads its checkpoint and skips
9+
the crashed config.
10+
11+
Progress detection
12+
------------------
13+
Each crash should block a different config (since blocked configs are skipped
14+
on re-run). If the same config crashes twice, the autotuner is stuck and we
15+
give up.
16+
17+
Requirements
18+
------------
19+
``HELION_AUTOTUNE_CHECKPOINT_DIR`` must be set in the environment.
20+
21+
Usage
22+
-----
23+
::
24+
25+
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/ckpt \\
26+
python -m helion.experimental.crash_recovery -- COMMAND [ARGS...]
27+
28+
Examples
29+
--------
30+
::
31+
32+
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/autotune_ckpt \\
33+
python -m helion.experimental.crash_recovery -- python train.py
34+
"""
35+
36+
from __future__ import annotations
37+
38+
import argparse
39+
import glob
40+
import os
41+
import subprocess
42+
import sys
43+
from pathlib import Path
44+
45+
46+
def _log(msg: str) -> None:
47+
print(f"[crash-recovery] {msg}", file=sys.stderr)
48+
49+
50+
def main(argv: list[str] | None = None) -> int:
51+
parser = argparse.ArgumentParser(
52+
description="Autotuner crash recovery wrapper.",
53+
usage=(
54+
"HELION_AUTOTUNE_CHECKPOINT_DIR=/path/to/dir\n"
55+
" %(prog)s -- COMMAND [ARGS...]"
56+
),
57+
)
58+
parser.add_argument(
59+
"command",
60+
nargs=argparse.REMAINDER,
61+
help="Command to run (after '--' separator)",
62+
)
63+
args = parser.parse_args(argv)
64+
65+
# argparse.REMAINDER absorbs '--' as first element when present.
66+
command: list[str] = args.command
67+
if command and command[0] == "--":
68+
command = command[1:]
69+
if not command:
70+
parser.error("no command specified after --")
71+
72+
checkpoint_dir_str = os.environ.get("HELION_AUTOTUNE_CHECKPOINT_DIR", "")
73+
if not checkpoint_dir_str:
74+
print(
75+
"Error: HELION_AUTOTUNE_CHECKPOINT_DIR must be set.",
76+
file=sys.stderr,
77+
)
78+
return 1
79+
80+
checkpoint_dir = Path(checkpoint_dir_str)
81+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
82+
83+
attempt = 0
84+
last_config = ""
85+
86+
while True:
87+
attempt += 1
88+
89+
result = subprocess.run(command)
90+
exit_code = result.returncode
91+
92+
if exit_code == 0:
93+
return 0
94+
95+
# Look for any *.pending_config breadcrumb left by the autotuner.
96+
pending_files = sorted(glob.glob(str(checkpoint_dir / "*.pending_config")))
97+
98+
if pending_files:
99+
# Process the first pending file found (typically only one exists
100+
# because the process crashed while working on a single config).
101+
pending_path = Path(pending_files[0])
102+
hash_prefix = pending_path.stem # {hash} without .pending_config
103+
crashed_configs_path = checkpoint_dir / f"{hash_prefix}.crashed_configs"
104+
105+
config = pending_path.read_text().strip()
106+
pending_path.unlink()
107+
108+
with open(crashed_configs_path, "a") as f:
109+
f.write(config + "\n")
110+
111+
_log(f"Process crashed (exit code {exit_code}, attempt {attempt}).")
112+
_log(f"Blocked config: {config}")
113+
114+
# If the same config crashed again, the crashed config is not
115+
# being skipped -- the autotuner is stuck.
116+
if config == last_config:
117+
_log(
118+
"Same config crashed twice \u2014 the autotuner appears stuck."
119+
)
120+
_log(
121+
"All crashed configs have been recorded. You can re-run "
122+
"this script and it will resume from the latest "
123+
"checkpoint, skipping all previously recorded crashed "
124+
"configs."
125+
)
126+
return 1
127+
last_config = config
128+
129+
_log("Restarting from checkpoint...")
130+
else:
131+
# No pending file -- not a recoverable CUDA crash.
132+
return exit_code
133+
134+
135+
if __name__ == "__main__":
136+
sys.exit(main())

test/data/autotune_crash_helper.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""Helper script for crash recovery tests.
2+
3+
Run via:
4+
HELION_AUTOTUNE_CHECKPOINT_DIR=DIR \
5+
python -m helion.experimental.crash_recovery -- python test/data/autotune_crash_helper.py
6+
7+
On first run (when _CRASH_ON_FIRST_BENCHMARK or _CRASH_ON_FIRST_COMPILE is
8+
set and no counter file exists): patches do_bench / compile_config to trigger
9+
a hard crash, which exercises the pending_config breadcrumb and the crash
10+
recovery script. On subsequent runs: autotuning resumes from checkpoint
11+
normally, skipping the crashed config.
12+
13+
Without the crash env vars: runs autotuning normally (used to test that the
14+
crash recovery wrapper passes through a successful run).
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import os
20+
from pathlib import Path
21+
22+
import torch
23+
24+
# Configuration comes from the environment; the checkpoint directory is required.
checkpoint_dir = os.environ["HELION_AUTOTUNE_CHECKPOINT_DIR"]
crash_on_first_benchmark = os.environ.get("_CRASH_ON_FIRST_BENCHMARK", "")
crash_on_first_compile = os.environ.get("_CRASH_ON_FIRST_COMPILE", "")
# Marker file written just before the injected crash; once it exists, later
# runs of this script skip both crash-injection branches below.
counter_file = Path(checkpoint_dir) / "_crash_counter"

if crash_on_first_benchmark and not counter_file.exists():
    import triton
    import triton.language as tl

    import helion.autotuner.base_search as _bs

    @triton.jit
    def _ima_kernel(ptr):
        """Triton kernel that triggers illegal memory access."""
        # Store through a pointer offset by 1 << 40 elements from ``ptr``.
        bad_ptr = ptr + (1 << 40)
        tl.store(bad_ptr, tl.full([], 42.0, dtype=tl.float32))

    _original_do_bench = _bs.do_bench

    def _ima_do_bench(*args, **kwargs):  # type: ignore[no-untyped-def]
        """Replacement for ``base_search.do_bench`` that launches ``_ima_kernel``."""
        # Record that the crash was injected before triggering it.
        counter_file.write_text("done")
        # Restore original so this only fires once
        _bs.do_bench = _original_do_bench
        # Trigger real CUDA illegal memory access
        x = torch.zeros(1, device="cuda")
        _ima_kernel[(1,)](x)
        torch.cuda.synchronize()
        # Should not reach here — IMA raises an exception
        return _original_do_bench(*args, **kwargs)

    _bs.do_bench = _ima_do_bench

if crash_on_first_compile and not counter_file.exists():
    import helion.autotuner.base_search as _bs

    # Wrap _benchmark so the real breadcrumb-writing code runs, but
    # compile_config triggers a hard crash (os._exit) on first call.
    _original_benchmark = _bs.BaseSearch._benchmark

    def _crashing_benchmark(self, configs, **kwargs):  # type: ignore[no-untyped-def]
        """Swap in a crashing ``compile_config``, then run the real ``_benchmark``."""
        # NOTE(review): _orig_compile is assigned but never used afterwards.
        _orig_compile = self.kernel.compile_config

        def _crash_compile(*args, **kw):  # type: ignore[no-untyped-def]
            """Write the marker file and terminate the process with status 1."""
            counter_file.write_text("done")
            # Simulate a hard crash (SIGKILL / CUDA IMA) during
            # compile_config. os._exit bypasses all Python exception
            # handling including try/except.
            os._exit(1)

        self.kernel.compile_config = _crash_compile
        return _original_benchmark(self, configs, **kwargs)

    _bs.BaseSearch._benchmark = _crashing_benchmark  # type: ignore[assignment]

# Import and run real autotuning
from helion._testing import import_path  # noqa: E402

datadir = Path(__file__).parent
basic_kernels = import_path(datadir / "basic_kernels.py")

# Autotune the ``add`` kernel from basic_kernels.py on two 8x32 CUDA tensors,
# then check the tuned kernel's output against a plain tensor addition.
args = (torch.randn([8, 32], device="cuda"), torch.randn([8, 32], device="cuda"))
bound = basic_kernels.add.bind(args)
bound.settings.autotune_checkpoint_dir = checkpoint_dir
bound.settings.autotune_effort = "quick"
config = bound.autotune(args, force=True)
result = bound(*args)
torch.testing.assert_close(result, args[0] + args[1])

0 commit comments

Comments
 (0)