@@ -358,6 +358,7 @@ class BaseSearch(BaseAutotuner):
358358 "_skip_cache" ,
359359 "_autotune_metrics" ,
360360 "_stable_hash" ,
361+ "_crashed_config_strs" ,
361362 )
362363
363364 @classmethod
@@ -404,6 +405,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
404405 self ._precompile_tmpdir : tempfile .TemporaryDirectory [str ] | None = None
405406 self ._precompile_args_path : str | None = None
406407 self ._precompile_result_counter = count ()
408+ self ._crashed_config_strs : set [str ] = set ()
407409
408410 def _prepare (self ) -> None :
409411 """Some initialization deferred until autotuning actually runs.
@@ -503,6 +505,32 @@ def _try_load_checkpoint(self) -> bool:
503505 self .log (f"Resumed at generation { self ._current_generation } " )
504506 return True
505507
def _load_crashed_configs(self) -> None:
    """Merge previously crashed configs from disk into the skip set.

    Reads ``{stable_hash}.crashed_configs`` (written by the crash-recovery
    script) from ``settings.autotune_checkpoint_dir``. Each non-blank line,
    with surrounding whitespace stripped, is added to
    ``self._crashed_config_strs``.

    Does nothing when no checkpoint directory is configured or when the
    file does not exist.
    """
    checkpoint_dir_str = self.settings.autotune_checkpoint_dir
    if checkpoint_dir_str is None:
        return
    crashed_configs_path = (
        Path(checkpoint_dir_str) / f"{self._get_stable_hash()}.crashed_configs"
    )
    # Read directly instead of exists()-then-read: the file is managed
    # outside this method and may be removed between the check and the
    # read, which would otherwise surface as an unhandled
    # FileNotFoundError. A missing file simply means nothing to load.
    try:
        contents = crashed_configs_path.read_text()
    except FileNotFoundError:
        return
    # Strip each line once; blank lines are dropped.
    stripped_lines = (line.strip() for line in contents.splitlines())
    self._crashed_config_strs.update(entry for entry in stripped_lines if entry)
    if self._crashed_config_strs:
        self.log(
            f"Loaded {len(self._crashed_config_strs)} crashed config(s) to skip"
        )
526+
527+ def _get_pending_config_path (self ) -> Path | None :
528+ """Get path for pending-config sentinel, or None if checkpointing disabled."""
529+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
530+ if checkpoint_dir_str is None :
531+ return None
532+ return Path (checkpoint_dir_str ) / f"{ self ._get_stable_hash ()} .pending_config"
533+
506534 def _compute_baseline (
507535 self ,
508536 ) -> tuple [object , Sequence [int ], Sequence [object ] | None ]:
@@ -725,6 +753,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
725753 Returns:
726754 The performance of the configuration in ms.
727755 """
756+ # Skip configs that previously crashed the subprocess
757+ config_str = str (config )
758+ if config_str in self ._crashed_config_strs :
759+ self .log .warning (f"Skipping known-crashed config: { config } " )
760+ return inf
761+
728762 self ._autotune_metrics .num_configs_tested += 1
729763 self .counters ["benchmark" ] += 1
730764 self .log .debug (lambda : f"Running benchmark for { config !r} " )
@@ -1005,13 +1039,36 @@ def _benchmark(
10051039 A list of BenchmarkResult entries containing the configuration, compiled
10061040 callable, measured performance, status, and compilation time.
10071041 """
1042+ # Filter out known-crashed configs before compilation
1043+ if self ._crashed_config_strs :
1044+ original_len = len (configs )
1045+ configs = [c for c in configs if str (c ) not in self ._crashed_config_strs ]
1046+ skipped = original_len - len (configs )
1047+ if skipped :
1048+ self .log .warning (
1049+ f"Skipped { skipped } known-crashed config(s) before compilation"
1050+ )
1051+ if not configs :
1052+ return []
1053+
10081054 fns : list [Callable [..., object ]] = []
10091055 valid_configs : list [Config ] = []
10101056 futures : list [PrecompileFuture ] | None = None
1057+ pending_path = self ._get_pending_config_path ()
10111058 for i , config in enumerate (configs ):
1059+ # Write sentinel before compile so a hard crash (SIGKILL /
1060+ # CUDA IMA) leaves a trace the crash recovery script can find.
1061+ if pending_path is not None :
1062+ pending_path .write_text (str (config ))
10121063 try :
10131064 fn = self .kernel .compile_config (config , allow_print = False )
1014- except Exception :
1065+ except Exception as e :
1066+ if match_unrecoverable_runtime_error (e ):
1067+ # Leave sentinel for crash recovery — CUDA context is
1068+ # corrupted and the process cannot continue.
1069+ raise
1070+ if pending_path is not None :
1071+ pending_path .unlink (missing_ok = True )
10151072 # If all configs failed, raise error
10161073 if not valid_configs and i == len (configs ) - 1 :
10171074 raise
@@ -1021,9 +1078,14 @@ def _benchmark(
10211078 exc_info = True ,
10221079 )
10231080 continue
1081+ if pending_path is not None :
1082+ pending_path .unlink (missing_ok = True )
10241083 fns .append (fn )
10251084 valid_configs .append (config )
10261085 configs = valid_configs
1086+ # NOTE: precompile runs in separate subprocesses with isolated CUDA
1087+ # contexts; crashes there are caught via is_working checks, not
1088+ # sentinels.
10271089 if self .settings .autotune_precompile :
10281090 futures = list (
10291091 starmap (
@@ -1085,7 +1147,14 @@ def _benchmark(
10851147 )
10861148 )
10871149 # benchmark one-by-one to avoid noisy results
1150+ # Write pending-config sentinel; cleared after benchmark.
1151+ # On crash the file stays so the crash recovery script can
1152+ # detect which config caused the failure.
1153+ if pending_path is not None :
1154+ pending_path .write_text (str (config ))
10881155 perf = self .benchmark_function (config , fn )
1156+ if pending_path is not None :
1157+ pending_path .unlink (missing_ok = True )
10891158 status = "ok" if math .isfinite (perf ) else "error"
10901159 # Log completion after benchmarking
10911160 self .log .record_autotune_entry (
@@ -1190,6 +1259,7 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11901259
11911260 if not self ._try_load_checkpoint ():
11921261 self ._init_search ()
1262+ self ._load_crashed_configs ()
11931263 try :
11941264 best = self ._autotune ()
11951265 self ._cleanup_checkpoint ()
@@ -1291,6 +1361,12 @@ def _cleanup_checkpoint(self) -> None:
12911361 checkpoint_file .unlink ()
12921362 self .log (f"Checkpoint cleaned up: { checkpoint_file } " )
12931363
1364+ # Clean up crash-recovery artifacts
1365+ for suffix in (".pending_config" , ".crashed_configs" ):
1366+ artifact = Path (checkpoint_dir_str ) / f"{ stable_hash } { suffix } "
1367+ if artifact .exists ():
1368+ artifact .unlink ()
1369+
12941370 @staticmethod
12951371 def _serialize_numpy_rng_state (
12961372 state : tuple [str , Any , int , int , float ],
0 commit comments