Skip to content

Commit f8e06ac

Browse files
committed
[Autotuner] Add crash recovery bash script for unrecoverable CUDA errors
Add scripts/autotune_with_crash_recovery.sh — a bash wrapper that automatically recovers from CUDA errors (illegal memory access, misaligned address, etc.) that poison the GPU context and kill the autotuning process. How it works: - Before each benchmark, the autotuner writes the current config to a pending file (_pending_config.txt) in the checkpoint directory - If a CUDA error kills the process, the pending file survives on disk - The bash script detects it, appends the poison config to _bad_configs.txt, and re-launches the command from scratch - On re-launch, the autotuner loads its checkpoint + bad configs list, skips the poison config, and continues searching Usage: scripts/autotune_with_crash_recovery.sh \ --checkpoint-dir /tmp/ckpt -- python train.py stack-info: PR: #1921, branch: yf225/stack/91
1 parent 4872e5d commit f8e06ac

File tree

6 files changed

+615
-4
lines changed

6 files changed

+615
-4
lines changed

helion/autotuner/base_search.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
435435
self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
436436
self._precompile_args_path: str | None = None
437437
self._precompile_result_counter = count()
438+
self._bad_config_strs: set[str] = set()
438439

439440
def _prepare(self) -> None:
440441
"""Some initialization deferred until autotuning actually runs.
@@ -531,9 +532,53 @@ def _try_load_checkpoint(self) -> bool:
531532
# load_state_dict validates required keys and raises CheckpointError for issues
532533
self.load_state_dict(state)
533534

535+
# Load bad configs (from subprocess crash recovery)
536+
self._load_bad_configs()
537+
534538
self.log(f"Resumed at generation {self._current_generation}")
535539
return True
536540

541+
def _load_bad_configs(self) -> None:
542+
"""Load bad configs from _bad_configs.txt file."""
543+
from .subprocess_runner import load_bad_configs
544+
545+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
546+
if checkpoint_dir_str is not None:
547+
bad_configs_path = os.path.join(checkpoint_dir_str, "_bad_configs.txt")
548+
self._bad_config_strs |= load_bad_configs(bad_configs_path)
549+
550+
if self._bad_config_strs:
551+
self.log(
552+
f"Loaded {len(self._bad_config_strs)} bad config(s) to skip",
553+
)
554+
555+
def _write_pending_config(self, config_str: str) -> None:
556+
"""Write the config being benchmarked to the pending file."""
557+
from .subprocess_runner import write_pending
558+
559+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
560+
if checkpoint_dir_str is None:
561+
return
562+
write_pending(checkpoint_dir_str, config_str)
563+
564+
def _clear_pending_config(self) -> None:
565+
"""Remove the pending file after benchmark completes."""
566+
from .subprocess_runner import clear_pending
567+
568+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
569+
if checkpoint_dir_str is None:
570+
return
571+
clear_pending(checkpoint_dir_str)
572+
573+
def _bump_progress(self) -> None:
574+
"""Increment the configs-tested counter for crash recovery progress tracking."""
575+
from .subprocess_runner import bump_progress
576+
577+
checkpoint_dir_str = self.settings.autotune_checkpoint_dir
578+
if checkpoint_dir_str is None:
579+
return
580+
bump_progress(checkpoint_dir_str)
581+
537582
def _compute_baseline(
538583
self,
539584
) -> tuple[object, Sequence[int], Sequence[object] | None]:
@@ -752,9 +797,16 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
752797
Returns:
753798
The performance of the configuration in ms.
754799
"""
800+
# Skip configs that previously crashed the subprocess
801+
config_str = str(config)
802+
if config_str in self._bad_config_strs:
803+
self.log.warning(f"Skipping known-bad config: {config}")
804+
return inf
805+
755806
self._autotune_metrics.num_configs_tested += 1
756807
self.counters["benchmark"] += 1
757808
self.log.debug(lambda: f"Running benchmark for {config!r}")
809+
self._write_pending_config(config_str)
758810
_captured_output: list[str] = [""]
759811
_capture_ctx = (
760812
capture_output()
@@ -794,6 +846,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
794846
if not compile_success_all:
795847
return inf
796848

849+
_is_unrecoverable = False
797850
try:
798851
# TODO(jansel): early exit with fewer trials if early runs are slow
799852
self.log.debug(lambda: f"Running {config} at {datetime.datetime.now()}")
@@ -855,6 +908,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
855908
captured_output=_captured_output[0] or None,
856909
)
857910
if match_unrecoverable_runtime_error(e):
911+
_is_unrecoverable = True
858912
self.kernel.maybe_log_repro(self.log.error, self.args, config)
859913
raise exc.TritonUnrecoverableRuntimeError(
860914
reason=str(e),
@@ -908,6 +962,10 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
908962

909963
self._autotune_metrics.num_compile_failures += 1
910964
return inf
965+
finally:
966+
if not _is_unrecoverable:
967+
self._clear_pending_config()
968+
self._bump_progress()
911969

912970
def set_adaptive_compile_timeout(
913971
self,
@@ -1193,6 +1251,8 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11931251
exit_stack.callback(self.cleanup)
11941252

11951253
if not self._try_load_checkpoint():
1254+
# Load bad configs even on fresh starts (subprocess recovery)
1255+
self._load_bad_configs()
11961256
self._init_search()
11971257
try:
11981258
best = self._autotune()
@@ -1296,6 +1356,11 @@ def _cleanup_checkpoint(self) -> None:
12961356
checkpoint_file.unlink()
12971357
self.log(f"Checkpoint cleaned up: {checkpoint_file}")
12981358

1359+
# Clean up subprocess recovery artifacts
1360+
from .subprocess_runner import cleanup_subprocess_artifacts
1361+
1362+
cleanup_subprocess_artifacts(checkpoint_dir_str)
1363+
12991364
@staticmethod
13001365
def _serialize_numpy_rng_state(
13011366
state: tuple[str, Any, int, int, float],

helion/autotuner/logger.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -466,15 +466,25 @@ def format_triton_compile_failure(
466466
)
467467
)
468468

469+
# CUDA errors that poison the GPU context and require process restart.
470+
# Source: CUDA driver_types.h — all errors documented with
471+
# "To continue using CUDA, the process must be terminated and relaunched."
472+
# Substrings are matched case-insensitively against cudaGetErrorString output.
469473
_UNRECOVERABLE_RUNTIME_ERROR_RE: re.Pattern[str] = re.compile(
470474
"|".join(
471475
map(
472476
re.escape,
473477
[
474-
"illegal memory access",
475-
"misaligned address",
476-
"unspecified launch failure",
477-
"illegal instruction",
478+
"illegal memory access", # cudaErrorIllegalAddress (700)
479+
"misaligned address", # cudaErrorMisalignedAddress (716)
480+
"unspecified launch failure", # cudaErrorLaunchFailure (719)
481+
"illegal instruction", # cudaErrorIllegalInstruction (715)
482+
"device-side assert", # cudaErrorAssert (710)
483+
"hardware stack error", # cudaErrorHardwareStackError (714)
484+
"invalid program counter", # cudaErrorInvalidPc (718)
485+
"not supported on global/shared address space", # cudaErrorInvalidAddressSpace (717)
486+
"tensor memory not completely freed", # cudaErrorTensorMemoryLeak (721)
487+
"launch timed out", # cudaErrorLaunchTimeout (702)
478488
],
479489
)
480490
),
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""File I/O helpers for autotuner crash recovery.
2+
3+
The crash recovery protocol works with an external retry loop
4+
(scripts/autotune_with_crash_recovery.sh). Before benchmarking each
5+
config, the autotuner writes its string representation to a pending
6+
file. If the process crashes (e.g. CUDA illegal memory access), the
7+
pending file survives and the external retry loop records it as a bad
8+
config. On re-run, the autotuner loads the checkpoint + bad configs
9+
and skips the poison config.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import os
15+
from pathlib import Path
16+
17+
_PENDING_FILENAME = "_pending_config.txt"
18+
_BAD_CONFIGS_FILENAME = "_bad_configs.txt"
19+
20+
21+
def write_pending(checkpoint_dir: str, config_str: str) -> None:
22+
"""Write the config being benchmarked to the pending file."""
23+
pending_path = Path(checkpoint_dir) / _PENDING_FILENAME
24+
pending_path.write_text(config_str)
25+
26+
27+
def clear_pending(checkpoint_dir: str) -> None:
28+
"""Remove the pending file after benchmark completes."""
29+
pending_path = Path(checkpoint_dir) / _PENDING_FILENAME
30+
if pending_path.exists():
31+
pending_path.unlink()
32+
33+
34+
def load_bad_configs(bad_configs_path: str) -> set[str]:
35+
"""Load bad config strings from file, one per line."""
36+
path = Path(bad_configs_path)
37+
if not path.exists():
38+
return set()
39+
lines = path.read_text().splitlines()
40+
return {line.strip() for line in lines if line.strip()}
41+
42+
43+
def _append_bad_config(bad_configs_path: str, config_str: str) -> None:
44+
"""Append a bad config string to the bad configs file."""
45+
with open(bad_configs_path, "a") as f:
46+
f.write(config_str + "\n")
47+
f.flush()
48+
os.fsync(f.fileno())
49+
50+
51+
_PROGRESS_FILENAME = "_configs_tested.txt"
52+
53+
54+
def bump_progress(checkpoint_dir: str) -> None:
55+
"""Increment the configs-tested counter.
56+
57+
Called after each benchmark completes (success or recoverable error)
58+
without crashing the process. The bash crash-recovery script reads
59+
this counter to detect whether the autotuner is making progress.
60+
"""
61+
progress_path = Path(checkpoint_dir) / _PROGRESS_FILENAME
62+
count = 0
63+
if progress_path.exists():
64+
with open(progress_path) as f:
65+
try:
66+
count = int(f.read().strip())
67+
except ValueError:
68+
pass
69+
progress_path.write_text(str(count + 1))
70+
71+
72+
def read_progress(checkpoint_dir: str) -> int:
73+
"""Read the configs-tested counter. Returns 0 if file doesn't exist."""
74+
progress_path = Path(checkpoint_dir) / _PROGRESS_FILENAME
75+
if not progress_path.exists():
76+
return 0
77+
try:
78+
return int(progress_path.read_text().strip())
79+
except ValueError:
80+
return 0
81+
82+
83+
def cleanup_subprocess_artifacts(checkpoint_dir: str) -> None:
84+
"""Remove crash-recovery files in the checkpoint directory."""
85+
checkpoint_path = Path(checkpoint_dir)
86+
for name in (
87+
_PENDING_FILENAME,
88+
_BAD_CONFIGS_FILENAME,
89+
_PROGRESS_FILENAME,
90+
):
91+
artifact = checkpoint_path / name
92+
if artifact.exists():
93+
artifact.unlink()
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#!/usr/bin/env bash
2+
# Autotuner crash recovery wrapper.
3+
#
4+
# Runs a command (typically a Python script that calls helion autotuning)
5+
# in a retry loop. When the process crashes due to an unrecoverable CUDA
6+
# error (illegal memory access, misaligned address, etc.), the autotuner
7+
# leaves a "_pending_config.txt" breadcrumb in the checkpoint directory.
8+
# This script detects that file, records the poison config in
9+
# "_bad_configs.txt", and re-runs the command. On re-run the autotuner
10+
# loads its checkpoint and skips the bad config.
11+
#
12+
# Progress detection:
13+
# The autotuner writes a counter to _configs_tested.txt after each
14+
# successful benchmark. This script checks whether the counter advanced
15+
# between crashes. If it did, the autotuner is making progress and we
16+
# keep retrying indefinitely. If the counter doesn't advance for 3
17+
# consecutive crashes, the autotuner is stuck and we give up.
18+
#
19+
# Requirements:
20+
# - HELION_AUTOTUNE_CHECKPOINT_DIR must be set
21+
#
22+
# Usage:
23+
# HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/ckpt \
24+
# scripts/autotune_with_crash_recovery.sh -- COMMAND [ARGS...]
25+
#
26+
# Examples:
27+
# HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/autotune_ckpt \
28+
# scripts/autotune_with_crash_recovery.sh -- python train.py
29+
30+
set -uo pipefail
31+
32+
MAX_NO_PROGRESS=3
33+
34+
# --- Argument parsing ---
35+
usage() {
36+
cat >&2 <<'EOF'
37+
Usage: HELION_AUTOTUNE_CHECKPOINT_DIR=/path/to/dir \
38+
autotune_with_crash_recovery.sh -- COMMAND [ARGS...]
39+
EOF
40+
exit "${1:-1}"
41+
}
42+
43+
while [[ $# -gt 0 ]]; do
44+
case "$1" in
45+
-h|--help)
46+
usage 0
47+
;;
48+
--)
49+
shift
50+
break
51+
;;
52+
*)
53+
echo "Error: unknown option '$1'" >&2
54+
usage 1
55+
;;
56+
esac
57+
done
58+
59+
if [[ $# -eq 0 ]]; then
60+
echo "Error: no command specified after --" >&2
61+
usage 1
62+
fi
63+
64+
if [[ -z "${HELION_AUTOTUNE_CHECKPOINT_DIR:-}" ]]; then
65+
echo "Error: HELION_AUTOTUNE_CHECKPOINT_DIR must be set." >&2
66+
exit 1
67+
fi
68+
69+
# --- Setup ---
70+
checkpoint_dir="$HELION_AUTOTUNE_CHECKPOINT_DIR"
71+
mkdir -p "$checkpoint_dir"
72+
73+
pending_file="$checkpoint_dir/_pending_config.txt"
74+
bad_configs_file="$checkpoint_dir/_bad_configs.txt"
75+
progress_file="$checkpoint_dir/_configs_tested.txt"
76+
77+
read_progress() {
78+
if [[ -f "$progress_file" ]]; then
79+
cat "$progress_file"
80+
else
81+
echo 0
82+
fi
83+
}
84+
85+
# --- Retry loop ---
86+
attempt=0
87+
no_progress_count=0
88+
last_progress=$(read_progress)
89+
90+
while true; do
91+
attempt=$((attempt + 1))
92+
93+
# Run the user command (don't use set -e, capture exit code manually)
94+
"$@"
95+
exit_code=$?
96+
97+
if [[ $exit_code -eq 0 ]]; then
98+
exit 0
99+
fi
100+
101+
# Check if the autotuner left a pending config breadcrumb
102+
if [[ -f "$pending_file" ]]; then
103+
config=$(cat "$pending_file")
104+
rm -f "$pending_file"
105+
echo "$config" >> "$bad_configs_file"
106+
107+
# Check progress: did the autotuner test any configs before crashing?
108+
current_progress=$(read_progress)
109+
if [[ "$current_progress" -gt "$last_progress" ]]; then
110+
configs_tested=$((current_progress - last_progress))
111+
echo "[crash-recovery] Process crashed (exit code $exit_code, attempt $attempt). Tested $configs_tested config(s) before crash." >&2
112+
no_progress_count=0
113+
last_progress=$current_progress
114+
else
115+
no_progress_count=$((no_progress_count + 1))
116+
echo "[crash-recovery] Process crashed (exit code $exit_code, attempt $attempt). No configs tested before crash ($no_progress_count/$MAX_NO_PROGRESS consecutive)." >&2
117+
fi
118+
echo "[crash-recovery] Blocked config: $config" >&2
119+
120+
if [[ $no_progress_count -ge $MAX_NO_PROGRESS ]]; then
121+
echo "[crash-recovery] No progress after $MAX_NO_PROGRESS consecutive crashes — the autotuner appears stuck." >&2
122+
echo "[crash-recovery] All bad configs have been recorded. You can re-run this script and it will resume from the latest checkpoint, skipping all previously recorded bad configs." >&2
123+
exit 1
124+
fi
125+
126+
echo "[crash-recovery] Restarting from checkpoint..." >&2
127+
else
128+
# No pending file — this is not a recoverable CUDA crash.
129+
# Propagate the original exit code.
130+
exit "$exit_code"
131+
fi
132+
done

0 commit comments

Comments
 (0)