1212import math
1313from math import inf
1414import os
15+ from pathlib import Path
16+ import pickle
1517import pprint
1618import random
1719import re
@@ -407,6 +409,114 @@ def cleanup(self) -> None:
407409 self ._precompile_args_path = None
408410 self ._precompile_result_counter = count ()
409411
# Attribute names that __getstate__ leaves out of a pickle checkpoint.
# Three groups: objects that cannot be pickled, values that _prepare()
# recomputes, and state that is restored through its own loader.
_CHECKPOINT_EXCLUDE = frozenset(
    (
        # Infrastructure that cannot be pickled
        "kernel",
        "args",
        "log",
        "settings",
        "config_spec",
        "_precompile_tmpdir",
        "_precompile_args_path",
        "_precompile_result_counter",
        # Rebuilt by _prepare() before a checkpoint is loaded
        "_baseline_output",
        "_baseline_post_args",
        "_mutated_arg_indices",
        "_effective_atol",
        "_effective_rtol",
        "_jobs",
        "_autotune_metrics",
        "_prepared",
        "_skip_cache",
        # Restored separately by _load_crashed_configs()
        "_crashed_config_strs",
    )
)
439+
def __getstate__(self) -> dict[str, Any]:
    """Return the instance state that goes into a pickle checkpoint.

    Every entry of ``__dict__`` is kept except those whose name is listed
    in ``_CHECKPOINT_EXCLUDE``.
    """
    state: dict[str, Any] = {}
    for name, value in self.__dict__.items():
        if name not in self._CHECKPOINT_EXCLUDE:
            state[name] = value
    return state
444+
445+ _stable_hash : str | None = None
446+
447+ def _get_stable_hash (self ) -> str :
448+ """Get the full stable hash for this kernel's cache key (cached)."""
449+ if self ._stable_hash is None :
450+ from .local_cache import LocalAutotuneCache
451+
452+ self ._stable_hash = LocalAutotuneCache (self )._generate_key ().stable_hash ()
453+ return self ._stable_hash
454+
def _try_load_checkpoint(self) -> bool:
    """Try to resume search state from a checkpoint in the checkpoint dir.

    Looks for ``<stable_hash>.pt`` inside ``settings.autotune_checkpoint_dir``.
    When the file loads and validates, the unpickled object's ``__dict__``
    is merged into ``self``.

    Returns:
        True if a checkpoint was loaded. False if checkpointing is disabled
        or no checkpoint file exists for this stable hash.

    Raises:
        exc.CheckpointError: If the file cannot be unpickled, does not
            contain an instance of this autotuner's class, or records a
            stable hash different from the current one.
    """
    checkpoint_dir_str = self.settings.autotune_checkpoint_dir
    if checkpoint_dir_str is None:
        return False

    stable_hash = self._get_stable_hash()
    checkpoint_file = Path(checkpoint_dir_str) / f"{stable_hash}.pt"

    if not checkpoint_file.exists():
        return False  # No matching checkpoint; start fresh

    # Matching file exists, attempt to load
    self.log(f"Resuming from checkpoint: {checkpoint_file}")
    # NOTE(security): pickle.load can execute arbitrary code, so the
    # checkpoint directory must only contain files from a trusted source.
    try:
        with open(checkpoint_file, "rb") as f:
            loaded = pickle.load(f)
    except Exception as e:
        raise exc.CheckpointError(
            f"Failed to load checkpoint file '{checkpoint_file}': {e}\n"
            f"The file may be corrupted. Delete it to start fresh."
        ) from e

    # A file that unpickles cleanly can still hold the wrong kind of object
    # (e.g. an unrelated pickle, or state saved by a different search
    # class). Merging its __dict__ would corrupt this instance or fail
    # with an unrelated AttributeError, so reject it explicitly.
    if not isinstance(loaded, type(self)):
        raise exc.CheckpointError(
            f"Checkpoint file '{checkpoint_file}' does not contain a "
            f"{type(self).__name__} (found {type(loaded).__name__}). "
            f"Delete it to start fresh."
        )

    # Validate stable hash matches (guards against renamed/copied files)
    loaded_hash = getattr(loaded, "_stable_hash", None)
    if loaded_hash is not None and loaded_hash != stable_hash:
        raise exc.CheckpointError(
            "Checkpoint is incompatible: kernel, hardware, or input shapes "
            "may have changed."
        )

    # Copy loaded search state into self (self already has kernel, args,
    # log, etc. from __init__ and _prepare())
    self.__dict__.update(loaded.__dict__)

    self.log(f"Resumed at generation {self._current_generation}")
    return True
493+
def _load_crashed_configs(self) -> None:
    """Merge entries from ``{hash}.crashed_configs`` into the crashed-config set.

    The file is written by the crash-recovery script. Each non-blank line,
    stripped of surrounding whitespace, is added to
    ``_crashed_config_strs``. Does nothing when checkpointing is disabled.
    """
    checkpoint_dir_str = self.settings.autotune_checkpoint_dir
    if checkpoint_dir_str is None:
        return

    file_name = f"{self._get_stable_hash()}.crashed_configs"
    crashed_configs_path = Path(checkpoint_dir_str) / file_name
    if crashed_configs_path.exists():
        stripped_lines = (
            raw.strip() for raw in crashed_configs_path.read_text().splitlines()
        )
        self._crashed_config_strs |= {entry for entry in stripped_lines if entry}

    if self._crashed_config_strs:
        self.log(
            f"Loaded {len(self._crashed_config_strs)} crashed config(s) to skip"
        )
512+
513+ def _get_pending_config_path (self ) -> Path | None :
514+ """Get path for pending-config sentinel, or None if checkpointing disabled."""
515+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
516+ if checkpoint_dir_str is None :
517+ return None
518+ return Path (checkpoint_dir_str ) / f"{ self ._get_stable_hash ()} .pending_config"
519+
410520 def _compute_baseline (
411521 self ,
412522 ) -> tuple [object , Sequence [int ], Sequence [object ] | None ]:
@@ -1091,9 +1201,13 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
10911201 self ._precompile_args_path = args_path
10921202 exit_stack .callback (self .cleanup )
10931203
1094- self ._init_search ()
1204+ checkpoint_enabled = self .settings .autotune_checkpoint_dir is not None
1205+ if not (checkpoint_enabled and self ._try_load_checkpoint ()):
1206+ self ._init_search ()
10951207 try :
10961208 best = self ._autotune ()
1209+ if checkpoint_enabled :
1210+ self ._cleanup_checkpoint ()
10971211 finally :
10981212 self ._finalize_autotune_metrics ()
10991213 end = time .perf_counter ()
@@ -1119,6 +1233,7 @@ def _init_search(self) -> None:
11191233 """
11201234 Initialize the search state for a fresh autotuning run.
11211235
1236+ This method is called when starting autotuning without a checkpoint.
11221237 Subclasses should override this to set up initial population and state.
11231238 After this method, _current_generation should be set to the generation
11241239 that _autotune() should start its loop from.
@@ -1135,6 +1250,68 @@ def _autotune(self) -> Config:
11351250 """
11361251 raise NotImplementedError
11371252
1253+ def save_checkpoint (self ) -> Path | None :
1254+ """
1255+ Save current autotuner state to checkpoint file.
1256+
1257+ Only saves when autotune_checkpoint_dir is set (opt-in).
1258+ Overwrites the same file each generation (keyed by stable hash).
1259+ Uses pickle to serialize the entire autotuner object (minus unpicklable
1260+ fields excluded by __getstate__).
1261+
1262+ Returns:
1263+ Path to saved checkpoint file, or None if not saved
1264+ """
1265+ from ..runtime .kernel import BoundKernel
1266+
1267+ # External kernels don't support caching/checkpointing
1268+ if not isinstance (self .kernel , BoundKernel ):
1269+ return None
1270+
1271+ if not self .kernel .is_cacheable ():
1272+ return None
1273+
1274+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
1275+ if checkpoint_dir_str is None :
1276+ return None # Opt-in: no dir set, no saving
1277+
1278+ stable_hash = self ._get_stable_hash ()
1279+ checkpoint_dir = Path (checkpoint_dir_str )
1280+ checkpoint_dir .mkdir (parents = True , exist_ok = True )
1281+ checkpoint_path = checkpoint_dir / f"{ stable_hash } .pt"
1282+
1283+ # Atomic write using temp file + rename
1284+ tmp = checkpoint_dir / f".tmp.{ stable_hash } .{ os .getpid ()} "
1285+ with open (tmp , "wb" ) as f :
1286+ pickle .dump (self , f )
1287+ os .replace (tmp , checkpoint_path )
1288+
1289+ self .log (f"Checkpoint saved: { checkpoint_path } " )
1290+ return checkpoint_path
1291+
def _cleanup_checkpoint(self) -> None:
    """Delete checkpoint file on successful autotune completion.

    Checkpoints are ephemeral in-progress state. Once autotuning
    completes successfully, the result is cached normally and the
    checkpoint is no longer needed.

    Also removes the ``.pending_config`` and ``.crashed_configs``
    crash-recovery artifacts for the same stable hash. Files that are
    already gone are ignored; other OS errors still propagate.
    """
    checkpoint_dir_str = self.settings.autotune_checkpoint_dir
    if checkpoint_dir_str is None:
        return

    checkpoint_dir = Path(checkpoint_dir_str)
    stable_hash = self._get_stable_hash()

    # Attempt the delete directly instead of exists()-then-unlink(): a file
    # removed in between would otherwise raise FileNotFoundError here.
    checkpoint_file = checkpoint_dir / f"{stable_hash}.pt"
    try:
        checkpoint_file.unlink()
    except FileNotFoundError:
        pass  # Nothing to clean up
    else:
        self.log(f"Checkpoint cleaned up: {checkpoint_file}")

    # Clean up crash-recovery artifacts
    for suffix in (".pending_config", ".crashed_configs"):
        artifact = checkpoint_dir / f"{stable_hash}{suffix}"
        artifact.unlink(missing_ok=True)
1314+
11381315 def set_generation (self , generation : int ) -> None :
11391316 self ._autotune_metrics .num_generations = generation
11401317
@@ -1189,6 +1366,15 @@ class PopulationMember:
11891366 def perf (self ) -> float :
11901367 return self .perfs [- 1 ]
11911368
def __getstate__(self) -> dict[str, Any]:
    """Return a shallow copy of the instance state with ``fn`` set to None."""
    # compiled functions are not picklable, so the entry is blanked out
    return {**self.__dict__, "fn": None}
1373+
def __setstate__(self, state: dict[str, Any]) -> None:
    """Restore pickled state, then reset ``fn`` to the ``_unset_fn`` placeholder."""
    vars(self).update(state)
    # The pickled state carries fn=None; swap in the module's placeholder.
    self.fn = _unset_fn
1377+
11921378
11931379def performance (member : PopulationMember ) -> float :
11941380 """
@@ -1587,6 +1773,8 @@ def set_generation(self, generation: int) -> None:
15871773 return
15881774 self ._current_generation = generation
15891775 super ().set_generation (generation )
1776+ if generation > 0 and self .settings .autotune_checkpoint_dir is not None :
1777+ self .save_checkpoint ()
15901778
15911779 def statistics (self ) -> str :
15921780 """
@@ -1597,6 +1785,30 @@ def statistics(self) -> str:
15971785 """
15981786 return population_statistics (self .population )
15991787
def _try_load_checkpoint(self) -> bool:
    """Load a checkpoint via the parent class, then recompile population members.

    Members whose ``fn`` is still the ``_unset_fn`` placeholder and whose
    status is ``"ok"`` are recompiled with ``kernel.compile_config``. A
    member that fails to recompile is marked ``"error"`` and gets an
    ``inf`` perf appended; one summary warning is logged for all failures.

    Returns:
        False if the parent did not load a checkpoint, True otherwise.
    """
    if not super()._try_load_checkpoint():
        return False

    # Recompile kernel functions for population members after checkpoint load
    failures: list[tuple[PopulationMember, str]] = []
    for candidate in self.population:
        if not (candidate.fn is _unset_fn and candidate.status == "ok"):
            continue
        try:
            candidate.fn = self.kernel.compile_config(
                candidate.config, allow_print=False
            )
        except Exception as err:
            candidate.fn = _unset_fn
            candidate.status = "error"
            # Ensure member won't be selected as best
            candidate.perfs.append(inf)
            failures.append((candidate, str(err)))

    if failures:
        _, first_reason = failures[0]
        self.log(
            f"Warning: {len(failures)} config(s) failed to recompile "
            f"and will be skipped. First failure: {first_reason}"
        )
    return True
1811+
16001812 def run_finishing_phase (
16011813 self , best : PopulationMember , rounds : int
16021814 ) -> PopulationMember :
0 commit comments