@@ -350,6 +350,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
350350 self ._precompile_tmpdir : tempfile .TemporaryDirectory [str ] | None = None
351351 self ._precompile_args_path : str | None = None
352352 self ._precompile_result_counter = count ()
353+ self ._crashed_config_strs : set [str ] = set ()
353354
354355 def _prepare (self ) -> None :
355356 """Some initialization deferred until autotuning actually runs.
@@ -494,6 +495,32 @@ def _try_load_checkpoint(self) -> bool:
494495
495496 def _recompile_after_checkpoint (self ) -> None :
496497 """Recompile after loading a checkpoint. Override in subclasses."""
498+
499+ def _load_crashed_configs (self ) -> None :
500+ """Load crashed configs from {hash}.crashed_configs (written by crash-recovery script)."""
501+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
502+ if checkpoint_dir_str is None :
503+ return
504+ crashed_configs_path = (
505+ Path (checkpoint_dir_str ) / f"{ self ._get_stable_hash ()} .crashed_configs"
506+ )
507+ if crashed_configs_path .exists ():
508+ self ._crashed_config_strs |= {
509+ line .strip ()
510+ for line in crashed_configs_path .read_text ().splitlines ()
511+ if line .strip ()
512+ }
513+ if self ._crashed_config_strs :
514+ self .log (
515+ f"Loaded { len (self ._crashed_config_strs )} crashed config(s) to skip"
516+ )
517+
518+ def _get_pending_config_path (self ) -> Path | None :
519+ """Get path for pending-config sentinel, or None if checkpointing disabled."""
520+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
521+ if checkpoint_dir_str is None :
522+ return None
523+ return Path (checkpoint_dir_str ) / f"{ self ._get_stable_hash ()} .pending_config"
497524 def _compute_baseline (
498525 self ,
499526 ) -> tuple [object , Sequence [int ], Sequence [object ] | None ]:
@@ -716,6 +743,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
716743 Returns:
717744 The performance of the configuration in ms.
718745 """
746+ # Skip configs that previously crashed the subprocess
747+ config_str = str (config )
748+ if config_str in self ._crashed_config_strs :
749+ self .log .warning (f"Skipping known-crashed config: { config } " )
750+ return inf
751+
719752 self ._autotune_metrics .num_configs_tested += 1
720753 self .log .debug (lambda : f"Running benchmark for { config !r} " )
721754 _captured_output : list [str ] = ["" ]
@@ -995,13 +1028,36 @@ def _benchmark(
9951028 A list of BenchmarkResult entries containing the configuration, compiled
9961029 callable, measured performance, status, and compilation time.
9971030 """
1031+ # Filter out known-crashed configs before compilation
1032+ if self ._crashed_config_strs :
1033+ original_len = len (configs )
1034+ configs = [c for c in configs if str (c ) not in self ._crashed_config_strs ]
1035+ skipped = original_len - len (configs )
1036+ if skipped :
1037+ self .log .warning (
1038+ f"Skipped { skipped } known-crashed config(s) before compilation"
1039+ )
1040+ if not configs :
1041+ return []
1042+
9981043 fns : list [Callable [..., object ]] = []
9991044 valid_configs : list [Config ] = []
10001045 futures : list [PrecompileFuture ] | None = None
1046+ pending_path = self ._get_pending_config_path ()
10011047 for i , config in enumerate (configs ):
1048+ # Write sentinel before compile so a hard crash (SIGKILL /
1049+ # CUDA IMA) leaves a trace the crash recovery script can find.
1050+ if pending_path is not None :
1051+ pending_path .write_text (str (config ))
10021052 try :
10031053 fn = self .kernel .compile_config (config , allow_print = False )
1004- except Exception :
1054+ except Exception as e :
1055+ if match_unrecoverable_runtime_error (e ):
1056+ # Leave sentinel for crash recovery — CUDA context is
1057+ # corrupted and the process cannot continue.
1058+ raise
1059+ if pending_path is not None :
1060+ pending_path .unlink (missing_ok = True )
10051061 # If all configs failed, raise error
10061062 if not valid_configs and i == len (configs ) - 1 :
10071063 raise
@@ -1011,9 +1067,14 @@ def _benchmark(
10111067 exc_info = True ,
10121068 )
10131069 continue
1070+ if pending_path is not None :
1071+ pending_path .unlink (missing_ok = True )
10141072 fns .append (fn )
10151073 valid_configs .append (config )
10161074 configs = valid_configs
1075+ # NOTE: precompile runs in separate subprocesses with isolated CUDA
1076+ # contexts; crashes there are caught via is_working checks, not
1077+ # sentinels.
10171078 if self .settings .autotune_precompile :
10181079 futures = list (
10191080 starmap (
@@ -1075,7 +1136,14 @@ def _benchmark(
10751136 )
10761137 )
10771138 # benchmark one-by-one to avoid noisy results
1139+ # Write pending-config sentinel; cleared after benchmark.
1140+ # On crash the file stays so the crash recovery script can
1141+ # detect which config caused the failure.
1142+ if pending_path is not None :
1143+ pending_path .write_text (str (config ))
10781144 perf = self .benchmark_function (config , fn )
1145+ if pending_path is not None :
1146+ pending_path .unlink (missing_ok = True )
10791147 status = "ok" if math .isfinite (perf ) else "error"
10801148 # Log completion after benchmarking
10811149 self .log .record_autotune_entry (
@@ -1180,6 +1248,7 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11801248
11811249 if not self ._try_load_checkpoint ():
11821250 self ._init_search ()
1251+ self ._load_crashed_configs ()
11831252 try :
11841253 best = self ._autotune ()
11851254 self ._cleanup_checkpoint ()
0 commit comments