Skip to content

Commit d2ee5a6

Browse files
committed
[Autotuner] Add crash recovery bash script for unrecoverable CUDA errors
1 parent 4872e5d commit d2ee5a6

File tree

4 files changed

+313
-1
lines changed

4 files changed

+313
-1
lines changed

helion/autotuner/base_search.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
from helion._dist_utils import is_symm_mem_tensor
7575

7676
if TYPE_CHECKING:
77+
from collections.abc import Iterator
7778
from collections.abc import Sequence
7879

7980
from ..runtime.kernel import BoundKernel
@@ -435,6 +436,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
435436
self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
436437
self._precompile_args_path: str | None = None
437438
self._precompile_result_counter = count()
439+
self._bad_config_strs: set[str] = set()
438440

439441
def _prepare(self) -> None:
440442
"""Some initialization deferred until autotuning actually runs.
@@ -534,6 +536,41 @@ def _try_load_checkpoint(self) -> bool:
534536
self.log(f"Resumed at generation {self._current_generation}")
535537
return True
536538

539+
def _load_bad_configs(self) -> None:
540+
"""Load bad configs from _bad_configs.txt (written by crash-recovery script)."""
541+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
542+
if checkpoint_dir_str is None:
543+
return
544+
bad_configs_path = Path(checkpoint_dir_str) / "_bad_configs.txt"
545+
if bad_configs_path.exists():
546+
self._bad_config_strs |= {
547+
line.strip()
548+
for line in bad_configs_path.read_text().splitlines()
549+
if line.strip()
550+
}
551+
if self._bad_config_strs:
552+
self.log(f"Loaded {len(self._bad_config_strs)} bad config(s) to skip")
553+
554+
@contextlib.contextmanager
555+
def _pending_config(self, config: Config) -> Iterator[None]:
556+
"""Write a pending-config breadcrumb before benchmarking, clear it after.
557+
558+
If the body raises TritonUnrecoverableRuntimeError the pending file
559+
is intentionally *not* cleared so the crash-recovery script can detect it.
560+
"""
561+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
562+
if checkpoint_dir_str is None:
563+
yield
564+
return
565+
pending_path = Path(checkpoint_dir_str) / "_pending_config.txt"
566+
pending_path.write_text(str(config))
567+
try:
568+
yield
569+
except exc.TritonUnrecoverableRuntimeError: # noqa: TRY203
570+
raise
571+
else:
572+
pending_path.unlink(missing_ok=True)
573+
537574
def _compute_baseline(
538575
self,
539576
) -> tuple[object, Sequence[int], Sequence[object] | None]:
@@ -752,6 +789,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
752789
Returns:
753790
The performance of the configuration in ms.
754791
"""
792+
# Skip configs that previously crashed the subprocess
793+
config_str = str(config)
794+
if config_str in self._bad_config_strs:
795+
self.log.warning(f"Skipping known-bad config: {config}")
796+
return inf
797+
755798
self._autotune_metrics.num_configs_tested += 1
756799
self.counters["benchmark"] += 1
757800
self.log.debug(lambda: f"Running benchmark for {config!r}")
@@ -1089,7 +1132,8 @@ def _benchmark(
10891132
)
10901133
)
10911134
# benchmark one-by-one to avoid noisy results
1092-
perf = self.benchmark_function(config, fn)
1135+
with self._pending_config(config):
1136+
perf = self.benchmark_function(config, fn)
10931137
status = "ok" if math.isfinite(perf) else "error"
10941138
# Log completion after benchmarking
10951139
self.log.record_autotune_entry(
@@ -1194,6 +1238,7 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11941238

11951239
if not self._try_load_checkpoint():
11961240
self._init_search()
1241+
self._load_bad_configs()
11971242
try:
11981243
best = self._autotune()
11991244
self._cleanup_checkpoint()
@@ -1296,6 +1341,12 @@ def _cleanup_checkpoint(self) -> None:
12961341
checkpoint_file.unlink()
12971342
self.log(f"Checkpoint cleaned up: {checkpoint_file}")
12981343

1344+
# Clean up crash-recovery artifacts
1345+
for name in ("_pending_config.txt", "_bad_configs.txt"):
1346+
artifact = Path(checkpoint_dir_str) / name
1347+
if artifact.exists():
1348+
artifact.unlink()
1349+
12991350
@staticmethod
13001351
def _serialize_numpy_rng_state(
13011352
state: tuple[str, Any, int, int, float],
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#!/usr/bin/env bash
# Autotuner crash recovery wrapper.
#
# Runs a command (typically a Python script that calls helion autotuning)
# in a retry loop. When the process crashes due to an unrecoverable CUDA
# error (illegal memory access, misaligned address, etc.), the autotuner
# leaves a "_pending_config.txt" breadcrumb in the checkpoint directory.
# This script detects that file, records the poison config in
# "_bad_configs.txt", and re-runs the command. On re-run the autotuner
# loads its checkpoint and skips the bad config.
#
# Progress detection:
#   Each crash should block a different config (since blocked configs are
#   skipped on re-run). If the same config crashes twice, the autotuner
#   is stuck and we give up.
#
# Requirements:
#   - HELION_AUTOTUNE_CHECKPOINT_DIR must be set
#
# Usage:
#   HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/ckpt \
#     scripts/autotune_with_crash_recovery.sh -- COMMAND [ARGS...]
#
# Examples:
#   HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/autotune_ckpt \
#     scripts/autotune_with_crash_recovery.sh -- python train.py

# Note: deliberately NOT using `set -e` — the whole point is to observe a
# nonzero exit code from the wrapped command and decide whether to retry.
set -uo pipefail

# --- Argument parsing ---
# Print usage to stderr and exit with the given status (default 1).
usage() {
    cat >&2 <<'EOF'
Usage: HELION_AUTOTUNE_CHECKPOINT_DIR=/path/to/dir \
       autotune_with_crash_recovery.sh -- COMMAND [ARGS...]
EOF
    exit "${1:-1}"
}

# Consume our own options up to the `--` separator; everything after it
# is the wrapped command line.
while [[ $# -gt 0 ]]; do
    case "$1" in
        -h|--help)
            usage 0
            ;;
        --)
            shift
            break
            ;;
        *)
            echo "Error: unknown option '$1'" >&2
            usage 1
            ;;
    esac
done

if [[ $# -eq 0 ]]; then
    echo "Error: no command specified after --" >&2
    usage 1
fi

if [[ -z "${HELION_AUTOTUNE_CHECKPOINT_DIR:-}" ]]; then
    echo "Error: HELION_AUTOTUNE_CHECKPOINT_DIR must be set." >&2
    exit 1
fi

# --- Setup ---
checkpoint_dir="$HELION_AUTOTUNE_CHECKPOINT_DIR"
mkdir -p "$checkpoint_dir"

# Breadcrumb written by the autotuner before each benchmark run.
pending_file="$checkpoint_dir/_pending_config.txt"
# Accumulated list of configs known to crash the process (one per line).
bad_configs_file="$checkpoint_dir/_bad_configs.txt"

# --- Retry loop ---
attempt=0
last_config=""

while true; do
    attempt=$((attempt + 1))

    # Run the user command (don't use set -e, capture exit code manually)
    "$@"
    exit_code=$?

    if [[ $exit_code -eq 0 ]]; then
        exit 0
    fi

    # Check if the autotuner left a pending config breadcrumb
    if [[ -f "$pending_file" ]]; then
        config=$(cat "$pending_file")
        rm -f "$pending_file"
        # Record the crashing config so the next run skips it.
        echo "$config" >> "$bad_configs_file"

        echo "[crash-recovery] Process crashed (exit code $exit_code, attempt $attempt)." >&2
        echo "[crash-recovery] Blocked config: $config" >&2

        # If the same config crashed again, the bad config is not being
        # skipped — the autotuner is stuck.
        if [[ "$config" == "$last_config" ]]; then
            echo "[crash-recovery] Same config crashed twice — the autotuner appears stuck." >&2
            echo "[crash-recovery] All bad configs have been recorded. You can re-run this script and it will resume from the latest checkpoint, skipping all previously recorded bad configs." >&2
            exit 1
        fi
        last_config="$config"

        echo "[crash-recovery] Restarting from checkpoint..." >&2
    else
        # No pending file — this is not a recoverable CUDA crash.
        # Propagate the original exit code.
        exit "$exit_code"
    fi
done

test/data/autotune_crash_helper.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Helper script for bash crash recovery tests.

Run via:
    HELION_AUTOTUNE_CHECKPOINT_DIR=DIR \
        scripts/autotune_with_crash_recovery.sh -- python test/data/autotune_crash_helper.py

On first run (when _CRASH_ON_FIRST_BENCHMARK is set and no counter file
exists): patches do_bench to trigger a real CUDA illegal memory access,
which exercises the real _pending_config context manager and
TritonUnrecoverableRuntimeError code path. On subsequent runs: autotuning
resumes from checkpoint normally, skipping the bad config.

Without _CRASH_ON_FIRST_BENCHMARK: runs autotuning normally (used to test
that the bash script passes through a successful run).
"""

from __future__ import annotations

import os
from pathlib import Path

import torch

# The checkpoint directory is shared with the bash wrapper; it is also
# where the one-shot crash marker (counter file) lives.
checkpoint_dir = os.environ["HELION_AUTOTUNE_CHECKPOINT_DIR"]
crash_on_first = os.environ.get("_CRASH_ON_FIRST_BENCHMARK", "")
# Presence of this file means the crash has already been injected once,
# so re-runs proceed without patching do_bench.
counter_file = Path(checkpoint_dir) / "_benchmark_counter"

if crash_on_first and not counter_file.exists():
    import triton
    import triton.language as tl

    import helion.autotuner.base_search as _bs

    @triton.jit
    def _ima_kernel(ptr):
        """Triton kernel that triggers illegal memory access."""
        # Offset the pointer far past any plausible allocation (1 TiB).
        bad_ptr = ptr + (1 << 40)
        tl.store(bad_ptr, tl.full([], 42.0, dtype=tl.float32))

    _original_do_bench = _bs.do_bench

    def _ima_do_bench(*args, **kwargs):  # type: ignore[no-untyped-def]
        # Mark the crash as injected BEFORE crashing, so the re-run
        # (after the wrapper restarts us) takes the normal path.
        counter_file.write_text("done")
        # Restore original so this only fires once
        _bs.do_bench = _original_do_bench
        # Trigger real CUDA illegal memory access
        x = torch.zeros(1, device="cuda")
        _ima_kernel[(1,)](x)
        torch.cuda.synchronize()
        # Should not reach here — IMA raises an exception
        return _original_do_bench(*args, **kwargs)

    _bs.do_bench = _ima_do_bench

# Import and run real autotuning
from helion._testing import import_path  # noqa: E402

datadir = Path(__file__).parent
basic_kernels = import_path(datadir / "basic_kernels.py")

args = (torch.randn([8, 32], device="cuda"), torch.randn([8, 32], device="cuda"))
bound = basic_kernels.add.bind(args)
# Point the autotuner at the wrapper-managed checkpoint dir so crash
# recovery and resume actually engage.
bound.settings.autotune_checkpoint_dir = checkpoint_dir
bound.settings.autotune_effort = "quick"
config = bound.autotune(args, force=True)
# Sanity-check that the tuned kernel still computes a + b correctly.
result = bound(*args)
torch.testing.assert_close(result, args[0] + args[1])

test/test_autotuner_subprocess.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from pathlib import Path
5+
import subprocess
6+
import tempfile
7+
8+
import pytest
9+
10+
# Absolute path to the crash-recovery bash wrapper under test.
SCRIPT = str(
    Path(__file__).parent.parent / "scripts" / "autotune_with_crash_recovery.sh"
)
# Helper driver script that the wrapper executes in the end-to-end tests.
HELPER = str(Path(__file__).parent / "data" / "autotune_crash_helper.py")
14+
15+
16+
class TestBashCrashRecoveryScript:
    """Tests for scripts/autotune_with_crash_recovery.sh.

    Every test shells out to the bash wrapper via subprocess.run(). The
    crash-recovery test drives test/data/autotune_crash_helper.py, which
    monkey-patches do_bench to provoke a genuine CUDA illegal memory
    access, exercising the real _pending_config context manager and the
    TritonUnrecoverableRuntimeError code path.
    """

    def _run_script(
        self,
        tmp_path: Path,
        cmd: list[str],
        extra_env: dict[str, str] | None = None,
    ) -> subprocess.CompletedProcess[str]:
        """Helper to run the bash script with HELION_AUTOTUNE_CHECKPOINT_DIR set."""
        env = dict(os.environ)
        env["HELION_AUTOTUNE_CHECKPOINT_DIR"] = str(tmp_path)
        env.update(extra_env or {})
        return subprocess.run(
            [SCRIPT, "--", *cmd], capture_output=True, text=True, env=env
        )

    def test_normal_exit(self, tmp_path: Path) -> None:
        """Successful command passes through exit 0."""
        proc = self._run_script(tmp_path, ["python", "-c", "pass"])
        assert proc.returncode == 0

    def test_no_pending_propagates_error(self, tmp_path: Path) -> None:
        """Non-CUDA crash (no pending file) propagates exit code."""
        proc = self._run_script(tmp_path, ["python", "-c", "import sys; sys.exit(42)"])
        assert proc.returncode == 42

    @pytest.mark.timeout(120)
    def test_real_autotune_through_bash(self) -> None:
        """End-to-end: real autotuning succeeds through the bash script."""
        with tempfile.TemporaryDirectory() as tmpdir:
            proc = self._run_script(
                Path(tmpdir),
                ["python", HELPER],
                extra_env={
                    "HELION_AUTOTUNE_MAX_GENERATIONS": "1",
                    "HELION_AUTOTUNER": "PatternSearch",
                },
            )
            assert proc.returncode == 0, f"stderr: {proc.stderr}"

    @pytest.mark.timeout(120)
    def test_real_crash_recovery_through_bash(self) -> None:
        """End-to-end: first run crashes during real benchmarking via
        monkey-patch, bash script detects pending file, records bad config,
        re-runs. Second run resumes from checkpoint and succeeds."""
        with tempfile.TemporaryDirectory() as tmpdir:
            proc = self._run_script(
                Path(tmpdir),
                ["python", HELPER],
                extra_env={
                    "_CRASH_ON_FIRST_BENCHMARK": "1",
                    "HELION_AUTOTUNE_MAX_GENERATIONS": "1",
                    "HELION_AUTOTUNER": "PatternSearch",
                },
            )
            assert proc.returncode == 0, f"stderr: {proc.stderr}"
            # Verify crash recovery happened (bad_configs.txt is cleaned
            # up by _cleanup_checkpoint on success, so check stderr)
            assert "[crash-recovery]" in proc.stderr
            assert "Blocked config:" in proc.stderr

0 commit comments

Comments
 (0)