Skip to content

Commit 491af69

Browse files
committed
[Autotuner] Add crash recovery bash script for unrecoverable CUDA errors
Add scripts/autotune_with_crash_recovery.sh — a bash wrapper that automatically recovers from CUDA errors (illegal memory access, misaligned address, etc.) that poison the GPU context and kill the autotuning process. How it works: - Before each benchmark, the autotuner writes the current config to a pending file (_pending_config.txt) in the checkpoint directory - If a CUDA error kills the process, the pending file survives on disk - The bash script detects it, appends the poison config to _bad_configs.txt, and re-launches the command from scratch - On re-launch, the autotuner loads its checkpoint + bad configs list, skips the poison config, and continues searching Usage: scripts/autotune_with_crash_recovery.sh \ --checkpoint-dir /tmp/ckpt -- python train.py stack-info: PR: #1921, branch: yf225/stack/91
1 parent 4872e5d commit 491af69

File tree

6 files changed

+529
-4
lines changed

6 files changed

+529
-4
lines changed

helion/autotuner/base_search.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
435435
self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
436436
self._precompile_args_path: str | None = None
437437
self._precompile_result_counter = count()
438+
self._bad_config_strs: set[str] = set()
438439

439440
def _prepare(self) -> None:
440441
"""Some initialization deferred until autotuning actually runs.
@@ -531,9 +532,44 @@ def _try_load_checkpoint(self) -> bool:
531532
# load_state_dict validates required keys and raises CheckpointError for issues
532533
self.load_state_dict(state)
533534

535+
# Load bad configs (from subprocess crash recovery)
536+
self._load_bad_configs()
537+
534538
self.log(f"Resumed at generation {self._current_generation}")
535539
return True
536540

541+
def _load_bad_configs(self) -> None:
542+
"""Load bad configs from _bad_configs.txt file."""
543+
from .subprocess_runner import load_bad_configs
544+
545+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
546+
if checkpoint_dir_str is not None:
547+
bad_configs_path = os.path.join(checkpoint_dir_str, "_bad_configs.txt")
548+
self._bad_config_strs |= load_bad_configs(bad_configs_path)
549+
550+
if self._bad_config_strs:
551+
self.log(
552+
f"Loaded {len(self._bad_config_strs)} bad config(s) to skip",
553+
)
554+
555+
def _write_pending_config(self, config_str: str) -> None:
556+
"""Write the config being benchmarked to the pending file."""
557+
from .subprocess_runner import write_pending
558+
559+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
560+
if checkpoint_dir_str is None:
561+
return
562+
write_pending(checkpoint_dir_str, config_str)
563+
564+
def _clear_pending_config(self) -> None:
565+
"""Remove the pending file after benchmark completes."""
566+
from .subprocess_runner import clear_pending
567+
568+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
569+
if checkpoint_dir_str is None:
570+
return
571+
clear_pending(checkpoint_dir_str)
572+
537573
def _compute_baseline(
538574
self,
539575
) -> tuple[object, Sequence[int], Sequence[object] | None]:
@@ -752,9 +788,16 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
752788
Returns:
753789
The performance of the configuration in ms.
754790
"""
791+
# Skip configs that previously crashed the subprocess
792+
config_str = str(config)
793+
if config_str in self._bad_config_strs:
794+
self.log.warning(f"Skipping known-bad config: {config}")
795+
return inf
796+
755797
self._autotune_metrics.num_configs_tested += 1
756798
self.counters["benchmark"] += 1
757799
self.log.debug(lambda: f"Running benchmark for {config!r}")
800+
self._write_pending_config(config_str)
758801
_captured_output: list[str] = [""]
759802
_capture_ctx = (
760803
capture_output()
@@ -794,6 +837,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
794837
if not compile_success_all:
795838
return inf
796839

840+
_is_unrecoverable = False
797841
try:
798842
# TODO(jansel): early exit with fewer trials if early runs are slow
799843
self.log.debug(lambda: f"Running {config} at {datetime.datetime.now()}")
@@ -855,6 +899,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
855899
captured_output=_captured_output[0] or None,
856900
)
857901
if match_unrecoverable_runtime_error(e):
902+
_is_unrecoverable = True
858903
self.kernel.maybe_log_repro(self.log.error, self.args, config)
859904
raise exc.TritonUnrecoverableRuntimeError(
860905
reason=str(e),
@@ -908,6 +953,9 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
908953

909954
self._autotune_metrics.num_compile_failures += 1
910955
return inf
956+
finally:
957+
if not _is_unrecoverable:
958+
self._clear_pending_config()
911959

912960
def set_adaptive_compile_timeout(
913961
self,
@@ -1193,6 +1241,8 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11931241
exit_stack.callback(self.cleanup)
11941242

11951243
if not self._try_load_checkpoint():
1244+
# Load bad configs even on fresh starts (subprocess recovery)
1245+
self._load_bad_configs()
11961246
self._init_search()
11971247
try:
11981248
best = self._autotune()
@@ -1296,6 +1346,11 @@ def _cleanup_checkpoint(self) -> None:
12961346
checkpoint_file.unlink()
12971347
self.log(f"Checkpoint cleaned up: {checkpoint_file}")
12981348

1349+
# Clean up subprocess recovery artifacts
1350+
from .subprocess_runner import cleanup_subprocess_artifacts
1351+
1352+
cleanup_subprocess_artifacts(checkpoint_dir_str)
1353+
12991354
@staticmethod
13001355
def _serialize_numpy_rng_state(
13011356
state: tuple[str, Any, int, int, float],

helion/autotuner/logger.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -466,15 +466,25 @@ def format_triton_compile_failure(
466466
)
467467
)
468468

469+
# CUDA errors that poison the GPU context and require process restart.
470+
# Source: CUDA driver_types.h — all errors documented with
471+
# "To continue using CUDA, the process must be terminated and relaunched."
472+
# Substrings are matched case-insensitively against cudaGetErrorString output.
469473
_UNRECOVERABLE_RUNTIME_ERROR_RE: re.Pattern[str] = re.compile(
470474
"|".join(
471475
map(
472476
re.escape,
473477
[
474-
"illegal memory access",
475-
"misaligned address",
476-
"unspecified launch failure",
477-
"illegal instruction",
478+
"illegal memory access", # cudaErrorIllegalAddress (700)
479+
"misaligned address", # cudaErrorMisalignedAddress (716)
480+
"unspecified launch failure", # cudaErrorLaunchFailure (719)
481+
"illegal instruction", # cudaErrorIllegalInstruction (715)
482+
"device-side assert", # cudaErrorAssert (710)
483+
"hardware stack error", # cudaErrorHardwareStackError (714)
484+
"invalid program counter", # cudaErrorInvalidPc (718)
485+
"not supported on global/shared address space", # cudaErrorInvalidAddressSpace (717)
486+
"tensor memory not completely freed", # cudaErrorTensorMemoryLeak (721)
487+
"launch timed out", # cudaErrorLaunchTimeout (702)
478488
],
479489
)
480490
),
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""File I/O helpers for autotuner crash recovery.
2+
3+
The crash recovery protocol works with an external retry loop
4+
(scripts/autotune_with_crash_recovery.sh). Before benchmarking each
5+
config, the autotuner writes its string representation to a pending
6+
file. If the process crashes (e.g. CUDA illegal memory access), the
7+
pending file survives and the external retry loop records it as a bad
8+
config. On re-run, the autotuner loads the checkpoint + bad configs
9+
and skips the poison config.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import os
15+
from pathlib import Path
16+
17+
_PENDING_FILENAME = "_pending_config.txt"
18+
_BAD_CONFIGS_FILENAME = "_bad_configs.txt"
19+
20+
21+
def write_pending(checkpoint_dir: str, config_str: str) -> None:
22+
"""Write the config being benchmarked to the pending file."""
23+
pending_path = Path(checkpoint_dir) / _PENDING_FILENAME
24+
pending_path.write_text(config_str)
25+
26+
27+
def clear_pending(checkpoint_dir: str) -> None:
28+
"""Remove the pending file after benchmark completes."""
29+
pending_path = Path(checkpoint_dir) / _PENDING_FILENAME
30+
if pending_path.exists():
31+
pending_path.unlink()
32+
33+
34+
def load_bad_configs(bad_configs_path: str) -> set[str]:
35+
"""Load bad config strings from file, one per line."""
36+
path = Path(bad_configs_path)
37+
if not path.exists():
38+
return set()
39+
lines = path.read_text().splitlines()
40+
return {line.strip() for line in lines if line.strip()}
41+
42+
43+
def _append_bad_config(bad_configs_path: str, config_str: str) -> None:
44+
"""Append a bad config string to the bad configs file."""
45+
with open(bad_configs_path, "a") as f:
46+
f.write(config_str + "\n")
47+
f.flush()
48+
os.fsync(f.fileno())
49+
50+
51+
def cleanup_subprocess_artifacts(checkpoint_dir: str) -> None:
52+
"""Remove crash-recovery files in the checkpoint directory."""
53+
checkpoint_path = Path(checkpoint_dir)
54+
for name in (
55+
_PENDING_FILENAME,
56+
_BAD_CONFIGS_FILENAME,
57+
):
58+
artifact = checkpoint_path / name
59+
if artifact.exists():
60+
artifact.unlink()
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#!/usr/bin/env bash
2+
# Autotuner crash recovery wrapper.
3+
#
4+
# Runs a command (typically a Python script that calls helion autotuning)
5+
# in a retry loop. When the process crashes due to an unrecoverable CUDA
6+
# error (illegal memory access, misaligned address, etc.), the autotuner
7+
# leaves a "_pending_config.txt" breadcrumb in the checkpoint directory.
8+
# This script detects that file, records the poison config in
9+
# "_bad_configs.txt", and re-runs the command. On re-run the autotuner
10+
# loads its checkpoint and skips the bad config.
11+
#
12+
# Progress detection:
13+
# Each crash should block a different config (since blocked configs are
14+
# skipped on re-run). If the same config crashes twice, the autotuner
15+
# is stuck and we give up.
16+
#
17+
# Requirements:
18+
# - HELION_AUTOTUNE_CHECKPOINT_DIR must be set
19+
#
20+
# Usage:
21+
# HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/ckpt \
22+
# scripts/autotune_with_crash_recovery.sh -- COMMAND [ARGS...]
23+
#
24+
# Examples:
25+
# HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/autotune_ckpt \
26+
# scripts/autotune_with_crash_recovery.sh -- python train.py
27+
28+
set -uo pipefail
29+
30+
# --- Argument parsing ---
31+
usage() {
32+
cat >&2 <<'EOF'
33+
Usage: HELION_AUTOTUNE_CHECKPOINT_DIR=/path/to/dir \
34+
autotune_with_crash_recovery.sh -- COMMAND [ARGS...]
35+
EOF
36+
exit "${1:-1}"
37+
}
38+
39+
while [[ $# -gt 0 ]]; do
40+
case "$1" in
41+
-h|--help)
42+
usage 0
43+
;;
44+
--)
45+
shift
46+
break
47+
;;
48+
*)
49+
echo "Error: unknown option '$1'" >&2
50+
usage 1
51+
;;
52+
esac
53+
done
54+
55+
if [[ $# -eq 0 ]]; then
56+
echo "Error: no command specified after --" >&2
57+
usage 1
58+
fi
59+
60+
if [[ -z "${HELION_AUTOTUNE_CHECKPOINT_DIR:-}" ]]; then
61+
echo "Error: HELION_AUTOTUNE_CHECKPOINT_DIR must be set." >&2
62+
exit 1
63+
fi
64+
65+
# --- Setup ---
66+
checkpoint_dir="$HELION_AUTOTUNE_CHECKPOINT_DIR"
67+
mkdir -p "$checkpoint_dir"
68+
69+
pending_file="$checkpoint_dir/_pending_config.txt"
70+
bad_configs_file="$checkpoint_dir/_bad_configs.txt"
71+
72+
# --- Retry loop ---
73+
attempt=0
74+
last_config=""
75+
76+
while true; do
77+
attempt=$((attempt + 1))
78+
79+
# Run the user command (don't use set -e, capture exit code manually)
80+
"$@"
81+
exit_code=$?
82+
83+
if [[ $exit_code -eq 0 ]]; then
84+
exit 0
85+
fi
86+
87+
# Check if the autotuner left a pending config breadcrumb
88+
if [[ -f "$pending_file" ]]; then
89+
config=$(cat "$pending_file")
90+
rm -f "$pending_file"
91+
echo "$config" >> "$bad_configs_file"
92+
93+
echo "[crash-recovery] Process crashed (exit code $exit_code, attempt $attempt)." >&2
94+
echo "[crash-recovery] Blocked config: $config" >&2
95+
96+
# If the same config crashed again, the bad config is not being
97+
# skipped — the autotuner is stuck.
98+
if [[ "$config" == "$last_config" ]]; then
99+
echo "[crash-recovery] Same config crashed twice — the autotuner appears stuck." >&2
100+
echo "[crash-recovery] All bad configs have been recorded. You can re-run this script and it will resume from the latest checkpoint, skipping all previously recorded bad configs." >&2
101+
exit 1
102+
fi
103+
last_config="$config"
104+
105+
echo "[crash-recovery] Restarting from checkpoint..." >&2
106+
else
107+
# No pending file — this is not a recoverable CUDA crash.
108+
# Propagate the original exit code.
109+
exit "$exit_code"
110+
fi
111+
done

test/data/autotune_crash_helper.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Helper script for bash crash recovery tests.
2+
3+
Run via:
4+
scripts/autotune_with_crash_recovery.sh --checkpoint-dir DIR -- python test/data/autotune_crash_helper.py
5+
6+
On first run (when _CRASH_ON_FIRST_BENCHMARK is set and no counter file
7+
exists): patches benchmark_function to crash after writing the pending file
8+
via the real code path. On subsequent runs: autotuning resumes from
9+
checkpoint normally, skipping the bad config.
10+
11+
Without _CRASH_ON_FIRST_BENCHMARK: runs autotuning normally (used to test
12+
that the bash script passes through a successful run).
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import os
18+
from pathlib import Path
19+
20+
import torch
21+
22+
checkpoint_dir = os.environ["HELION_AUTOTUNE_CHECKPOINT_DIR"]
23+
crash_on_first = os.environ.get("_CRASH_ON_FIRST_BENCHMARK", "")
24+
counter_file = Path(checkpoint_dir) / "_benchmark_counter"
25+
26+
if crash_on_first and not counter_file.exists():
27+
from helion.autotuner.base_search import BaseSearch
28+
29+
def _crashing_benchmark(self, config, fn): # type: ignore[no-untyped-def]
30+
counter_file.write_text("done")
31+
# Write pending file via the real code path
32+
self._write_pending_config(str(config))
33+
# Crash without clearing the pending file — simulates CUDA kill
34+
os._exit(1)
35+
36+
BaseSearch.benchmark_function = _crashing_benchmark # type: ignore[assignment]
37+
38+
# Import and run real autotuning
39+
from helion._testing import import_path # noqa: E402
40+
41+
datadir = Path(__file__).parent
42+
basic_kernels = import_path(datadir / "basic_kernels.py")
43+
44+
args = (torch.randn([8, 32], device="cuda"), torch.randn([8, 32], device="cuda"))
45+
bound = basic_kernels.add.bind(args)
46+
bound.settings.autotune_checkpoint_dir = checkpoint_dir
47+
bound.settings.autotune_effort = "quick"
48+
config = bound.autotune(args, force=True)
49+
result = bound(*args)
50+
torch.testing.assert_close(result, args[0] + args[1])

0 commit comments

Comments
 (0)