@@ -358,6 +358,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
358358 self ._precompile_tmpdir : tempfile .TemporaryDirectory [str ] | None = None
359359 self ._precompile_args_path : str | None = None
360360 self ._precompile_result_counter = count ()
361+ self ._crashed_config_strs : set [str ] = set ()
361362
362363 def _prepare (self ) -> None :
363364 """Some initialization deferred until autotuning actually runs.
@@ -501,6 +502,32 @@ def _try_load_checkpoint(self) -> bool:
501502
502503 def _recompile_after_checkpoint (self ) -> None :
503504 """Recompile after loading a checkpoint. Override in subclasses."""
505+
506+ def _load_crashed_configs (self ) -> None :
507+ """Load crashed configs from {hash}.crashed_configs (written by crash-recovery script)."""
508+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
509+ if checkpoint_dir_str is None :
510+ return
511+ crashed_configs_path = (
512+ Path (checkpoint_dir_str ) / f"{ self ._get_stable_hash ()} .crashed_configs"
513+ )
514+ if crashed_configs_path .exists ():
515+ self ._crashed_config_strs |= {
516+ line .strip ()
517+ for line in crashed_configs_path .read_text ().splitlines ()
518+ if line .strip ()
519+ }
520+ if self ._crashed_config_strs :
521+ self .log (
522+ f"Loaded { len (self ._crashed_config_strs )} crashed config(s) to skip"
523+ )
524+
525+ def _get_pending_config_path (self ) -> Path | None :
526+ """Get path for pending-config sentinel, or None if checkpointing disabled."""
527+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
528+ if checkpoint_dir_str is None :
529+ return None
530+ return Path (checkpoint_dir_str ) / f"{ self ._get_stable_hash ()} .pending_config"
504531 def _compute_baseline (
505532 self ,
506533 ) -> tuple [object , Sequence [int ], Sequence [object ] | None ]:
@@ -723,6 +750,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
723750 Returns:
724751 The performance of the configuration in ms.
725752 """
753+ # Skip configs that previously crashed the subprocess
754+ config_str = str (config )
755+ if config_str in self ._crashed_config_strs :
756+ self .log .warning (f"Skipping known-crashed config: { config } " )
757+ return inf
758+
726759 self ._autotune_metrics .num_configs_tested += 1
727760 self .counters ["benchmark" ] += 1
728761 self .log .debug (lambda : f"Running benchmark for { config !r} " )
@@ -1003,13 +1036,36 @@ def _benchmark(
10031036 A list of BenchmarkResult entries containing the configuration, compiled
10041037 callable, measured performance, status, and compilation time.
10051038 """
1039+ # Filter out known-crashed configs before compilation
1040+ if self ._crashed_config_strs :
1041+ original_len = len (configs )
1042+ configs = [c for c in configs if str (c ) not in self ._crashed_config_strs ]
1043+ skipped = original_len - len (configs )
1044+ if skipped :
1045+ self .log .warning (
1046+ f"Skipped { skipped } known-crashed config(s) before compilation"
1047+ )
1048+ if not configs :
1049+ return []
1050+
10061051 fns : list [Callable [..., object ]] = []
10071052 valid_configs : list [Config ] = []
10081053 futures : list [PrecompileFuture ] | None = None
1054+ pending_path = self ._get_pending_config_path ()
10091055 for i , config in enumerate (configs ):
1056+ # Write sentinel before compile so a hard crash (SIGKILL /
1057+ # CUDA IMA) leaves a trace the crash recovery script can find.
1058+ if pending_path is not None :
1059+ pending_path .write_text (str (config ))
10101060 try :
10111061 fn = self .kernel .compile_config (config , allow_print = False )
1012- except Exception :
1062+ except Exception as e :
1063+ if match_unrecoverable_runtime_error (e ):
1064+ # Leave sentinel for crash recovery — CUDA context is
1065+ # corrupted and the process cannot continue.
1066+ raise
1067+ if pending_path is not None :
1068+ pending_path .unlink (missing_ok = True )
10131069 # If all configs failed, raise error
10141070 if not valid_configs and i == len (configs ) - 1 :
10151071 raise
@@ -1019,9 +1075,14 @@ def _benchmark(
10191075 exc_info = True ,
10201076 )
10211077 continue
1078+ if pending_path is not None :
1079+ pending_path .unlink (missing_ok = True )
10221080 fns .append (fn )
10231081 valid_configs .append (config )
10241082 configs = valid_configs
1083+ # NOTE: precompile runs in separate subprocesses with isolated CUDA
1084+ # contexts; crashes there are caught via is_working checks, not
1085+ # sentinels.
10251086 if self .settings .autotune_precompile :
10261087 futures = list (
10271088 starmap (
@@ -1083,7 +1144,14 @@ def _benchmark(
10831144 )
10841145 )
10851146 # benchmark one-by-one to avoid noisy results
1147+ # Write pending-config sentinel; cleared after benchmark.
1148+ # On crash the file stays so the crash recovery script can
1149+ # detect which config caused the failure.
1150+ if pending_path is not None :
1151+ pending_path .write_text (str (config ))
10861152 perf = self .benchmark_function (config , fn )
1153+ if pending_path is not None :
1154+ pending_path .unlink (missing_ok = True )
10871155 status = "ok" if math .isfinite (perf ) else "error"
10881156 # Log completion after benchmarking
10891157 self .log .record_autotune_entry (
@@ -1188,6 +1256,7 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11881256
11891257 if not self ._try_load_checkpoint ():
11901258 self ._init_search ()
1259+ self ._load_crashed_configs ()
11911260 try :
11921261 best = self ._autotune ()
11931262 self ._cleanup_checkpoint ()
0 commit comments