@@ -367,6 +367,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
367367 self ._precompile_tmpdir : tempfile .TemporaryDirectory [str ] | None = None
368368 self ._precompile_args_path : str | None = None
369369 self ._precompile_result_counter = count ()
370+ self ._crashed_config_strs : set [str ] = set ()
370371
371372 def _prepare (self ) -> None :
372373 """Some initialization deferred until autotuning actually runs.
@@ -471,6 +472,32 @@ def _try_load_checkpoint(self) -> bool:
471472 self .log (f"Resumed at generation { self ._current_generation } " )
472473 return True
473474
def _load_crashed_configs(self) -> None:
    """Load crashed configs from {hash}.crashed_configs (written by crash-recovery script).

    Each non-blank line of the file is stripped of surrounding whitespace
    and merged into ``self._crashed_config_strs``; entries already in the
    set are kept. Nothing happens when
    ``settings.autotune_checkpoint_dir`` is None, and a missing file is
    treated as "no crashes recorded".
    """
    checkpoint_dir_str = self.settings.autotune_checkpoint_dir
    if checkpoint_dir_str is None:
        return
    crashed_configs_path = (
        Path(checkpoint_dir_str) / f"{self._get_stable_hash()}.crashed_configs"
    )
    # Read directly instead of checking exists() first: the file is written
    # by a separate script, so it may disappear between a check and the
    # read. A missing file simply contributes no entries.
    try:
        contents = crashed_configs_path.read_text()
    except FileNotFoundError:
        contents = ""
    self._crashed_config_strs.update(
        stripped for line in contents.splitlines() if (stripped := line.strip())
    )
    if self._crashed_config_strs:
        self.log(
            f"Loaded {len(self._crashed_config_strs)} crashed config(s) to skip"
        )
493+
def _get_pending_config_path(self) -> Path | None:
    """Return where the pending-config sentinel file lives.

    The sentinel is ``{stable_hash}.pending_config`` inside
    ``settings.autotune_checkpoint_dir``. Returns None when no checkpoint
    directory is configured, i.e. checkpointing is disabled.
    """
    base_dir = self.settings.autotune_checkpoint_dir
    if base_dir is not None:
        sentinel_name = f"{self._get_stable_hash()}.pending_config"
        return Path(base_dir).joinpath(sentinel_name)
    return None
500+
474501 def _compute_baseline (
475502 self ,
476503 ) -> tuple [object , Sequence [int ], Sequence [object ] | None ]:
@@ -693,6 +720,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
693720 Returns:
694721 The performance of the configuration in ms.
695722 """
723+ # Skip configs that previously crashed the subprocess
724+ config_str = str (config )
725+ if config_str in self ._crashed_config_strs :
726+ self .log .warning (f"Skipping known-crashed config: { config } " )
727+ return inf
728+
696729 self ._autotune_metrics .num_configs_tested += 1
697730 self .counters ["benchmark" ] += 1
698731 self .log .debug (lambda : f"Running benchmark for { config !r} " )
@@ -973,13 +1006,36 @@ def _benchmark(
9731006 A list of BenchmarkResult entries containing the configuration, compiled
9741007 callable, measured performance, status, and compilation time.
9751008 """
1009+ # Filter out known-crashed configs before compilation
1010+ if self ._crashed_config_strs :
1011+ original_len = len (configs )
1012+ configs = [c for c in configs if str (c ) not in self ._crashed_config_strs ]
1013+ skipped = original_len - len (configs )
1014+ if skipped :
1015+ self .log .warning (
1016+ f"Skipped { skipped } known-crashed config(s) before compilation"
1017+ )
1018+ if not configs :
1019+ return []
1020+
9761021 fns : list [Callable [..., object ]] = []
9771022 valid_configs : list [Config ] = []
9781023 futures : list [PrecompileFuture ] | None = None
1024+ pending_path = self ._get_pending_config_path ()
9791025 for i , config in enumerate (configs ):
1026+ # Write sentinel before compile so a hard crash (SIGKILL /
1027+ # CUDA IMA) leaves a trace the crash recovery script can find.
1028+ if pending_path is not None :
1029+ pending_path .write_text (str (config ))
9801030 try :
9811031 fn = self .kernel .compile_config (config , allow_print = False )
982- except Exception :
1032+ except Exception as e :
1033+ if match_unrecoverable_runtime_error (e ):
1034+ # Leave sentinel for crash recovery — CUDA context is
1035+ # corrupted and the process cannot continue.
1036+ raise
1037+ if pending_path is not None :
1038+ pending_path .unlink (missing_ok = True )
9831039 # If all configs failed, raise error
9841040 if not valid_configs and i == len (configs ) - 1 :
9851041 raise
@@ -989,9 +1045,14 @@ def _benchmark(
9891045 exc_info = True ,
9901046 )
9911047 continue
1048+ if pending_path is not None :
1049+ pending_path .unlink (missing_ok = True )
9921050 fns .append (fn )
9931051 valid_configs .append (config )
9941052 configs = valid_configs
1053+ # NOTE: precompile runs in separate subprocesses with isolated CUDA
1054+ # contexts; crashes there are caught via is_working checks, not
1055+ # sentinels.
9951056 if self .settings .autotune_precompile :
9961057 futures = list (
9971058 starmap (
@@ -1053,7 +1114,14 @@ def _benchmark(
10531114 )
10541115 )
10551116 # benchmark one-by-one to avoid noisy results
1117+ # Write pending-config sentinel; cleared after benchmark.
1118+ # On crash the file stays so the crash recovery script can
1119+ # detect which config caused the failure.
1120+ if pending_path is not None :
1121+ pending_path .write_text (str (config ))
10561122 perf = self .benchmark_function (config , fn )
1123+ if pending_path is not None :
1124+ pending_path .unlink (missing_ok = True )
10571125 status = "ok" if math .isfinite (perf ) else "error"
10581126 # Log completion after benchmarking
10591127 self .log .record_autotune_entry (
@@ -1158,6 +1226,7 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11581226
11591227 if not self ._try_load_checkpoint ():
11601228 self ._init_search ()
1229+ self ._load_crashed_configs ()
11611230 try :
11621231 best = self ._autotune ()
11631232 self ._cleanup_checkpoint ()
@@ -1259,6 +1328,12 @@ def _cleanup_checkpoint(self) -> None:
12591328 checkpoint_file .unlink ()
12601329 self .log (f"Checkpoint cleaned up: { checkpoint_file } " )
12611330
1331+ # Clean up crash-recovery artifacts
1332+ for suffix in (".pending_config" , ".crashed_configs" ):
1333+ artifact = Path (checkpoint_dir_str ) / f"{ stable_hash } { suffix } "
1334+ if artifact .exists ():
1335+ artifact .unlink ()
1336+
12621337 @staticmethod
12631338 def _serialize_numpy_rng_state (
12641339 state : tuple [str , Any , int , int , float ],
0 commit comments