@@ -369,6 +369,8 @@ class BaseSearch(BaseAutotuner):
369369 "_skip_cache" ,
370370 "_autotune_metrics" ,
371371 "_stable_hash" ,
372+ # Loaded separately from .crashed_configs file
373+ "_crashed_config_strs" ,
372374 }
373375
374376 @classmethod
@@ -423,6 +425,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
423425 self ._precompile_tmpdir : tempfile .TemporaryDirectory [str ] | None = None
424426 self ._precompile_args_path : str | None = None
425427 self ._precompile_result_counter = count ()
428+ self ._crashed_config_strs : set [str ] = set ()
426429
427430 def _prepare (self ) -> None :
428431 """Some initialization deferred until autotuning actually runs.
@@ -522,6 +525,32 @@ def _try_load_checkpoint(self) -> bool:
522525 self .log (f"Resumed at generation { self ._current_generation } " )
523526 return True
524527
528+ def _load_crashed_configs (self ) -> None :
529+ """Load crashed configs from {hash}.crashed_configs (written by crash-recovery script)."""
530+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
531+ if checkpoint_dir_str is None :
532+ return
533+ crashed_configs_path = (
534+ Path (checkpoint_dir_str ) / f"{ self ._get_stable_hash ()} .crashed_configs"
535+ )
536+ if crashed_configs_path .exists ():
537+ self ._crashed_config_strs |= {
538+ line .strip ()
539+ for line in crashed_configs_path .read_text ().splitlines ()
540+ if line .strip ()
541+ }
542+ if self ._crashed_config_strs :
543+ self .log (
544+ f"Loaded { len (self ._crashed_config_strs )} crashed config(s) to skip"
545+ )
546+
547+ def _get_pending_config_path (self ) -> Path | None :
548+ """Get path for pending-config sentinel, or None if checkpointing disabled."""
549+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
550+ if checkpoint_dir_str is None :
551+ return None
552+ return Path (checkpoint_dir_str ) / f"{ self ._get_stable_hash ()} .pending_config"
553+
525554 def _compute_baseline (
526555 self ,
527556 ) -> tuple [object , Sequence [int ], Sequence [object ] | None ]:
@@ -744,6 +773,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
744773 Returns:
745774 The performance of the configuration in ms.
746775 """
776+ # Skip configs that previously crashed the subprocess
777+ config_str = str (config )
778+ if config_str in self ._crashed_config_strs :
779+ self .log .warning (f"Skipping known-crashed config: { config } " )
780+ return inf
781+
747782 self ._autotune_metrics .num_configs_tested += 1
748783 self .counters ["benchmark" ] += 1
749784 self .log .debug (lambda : f"Running benchmark for { config !r} " )
@@ -1024,13 +1059,32 @@ def _benchmark(
10241059 A list of BenchmarkResult entries containing the configuration, compiled
10251060 callable, measured performance, status, and compilation time.
10261061 """
1062+ # Filter out known-crashed configs before compilation
1063+ if self ._crashed_config_strs :
1064+ original_len = len (configs )
1065+ configs = [c for c in configs if str (c ) not in self ._crashed_config_strs ]
1066+ skipped = original_len - len (configs )
1067+ if skipped :
1068+ self .log .warning (
1069+ f"Skipped { skipped } known-crashed config(s) before compilation"
1070+ )
1071+ if not configs :
1072+ return []
1073+
10271074 fns : list [Callable [..., object ]] = []
10281075 valid_configs : list [Config ] = []
10291076 futures : list [PrecompileFuture ] | None = None
1077+ pending_path = self ._get_pending_config_path ()
10301078 for i , config in enumerate (configs ):
1079+ # Write sentinel before compile so a hard crash (SIGKILL /
1080+ # CUDA IMA) leaves a trace the crash recovery script can find.
1081+ if pending_path is not None :
1082+ pending_path .write_text (str (config ))
10311083 try :
10321084 fn = self .kernel .compile_config (config , allow_print = False )
10331085 except Exception :
1086+ if pending_path is not None :
1087+ pending_path .unlink (missing_ok = True )
10341088 # If all configs failed, raise error
10351089 if not valid_configs and i == len (configs ) - 1 :
10361090 raise
@@ -1040,9 +1094,14 @@ def _benchmark(
10401094 exc_info = True ,
10411095 )
10421096 continue
1097+ if pending_path is not None :
1098+ pending_path .unlink (missing_ok = True )
10431099 fns .append (fn )
10441100 valid_configs .append (config )
10451101 configs = valid_configs
1102+ # NOTE: precompile runs in separate subprocesses with isolated CUDA
1103+ # contexts; crashes there are caught via is_working checks, not
1104+ # sentinels.
10461105 if self .settings .autotune_precompile :
10471106 futures = list (
10481107 starmap (
@@ -1104,7 +1163,14 @@ def _benchmark(
11041163 )
11051164 )
11061165 # benchmark one-by-one to avoid noisy results
1166+ # Write pending-config sentinel; cleared after benchmark.
1167+ # On crash the file stays so the crash recovery script can
1168+ # detect which config caused the failure.
1169+ if pending_path is not None :
1170+ pending_path .write_text (str (config ))
11071171 perf = self .benchmark_function (config , fn )
1172+ if pending_path is not None :
1173+ pending_path .unlink (missing_ok = True )
11081174 status = "ok" if math .isfinite (perf ) else "error"
11091175 # Log completion after benchmarking
11101176 self .log .record_autotune_entry (
@@ -1209,6 +1275,7 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
12091275
12101276 if not self ._try_load_checkpoint ():
12111277 self ._init_search ()
1278+ self ._load_crashed_configs ()
12121279 try :
12131280 best = self ._autotune ()
12141281 self ._cleanup_checkpoint ()
@@ -1311,6 +1378,12 @@ def _cleanup_checkpoint(self) -> None:
13111378 checkpoint_file .unlink ()
13121379 self .log (f"Checkpoint cleaned up: { checkpoint_file } " )
13131380
1381+ # Clean up crash-recovery artifacts
1382+ for suffix in (".pending_config" , ".crashed_configs" ):
1383+ artifact = Path (checkpoint_dir_str ) / f"{ stable_hash } { suffix } "
1384+ if artifact .exists ():
1385+ artifact .unlink ()
1386+
13141387 @staticmethod
13151388 def _serialize_numpy_rng_state (
13161389 state : tuple [str , Any , int , int , float ],
0 commit comments