1212import math
1313from math import inf
1414import os
15+ from pathlib import Path
16+ import pickle
1517import pprint
1618import random
1719import re
@@ -343,6 +345,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
343345 self .args : Sequence [object ] = args
344346 self .log = AutotuningLogger (self .settings )
345347 self .best_perf_so_far = inf
348+ self ._current_generation = 0
346349 self ._prepared = False
347350 self ._precompile_tmpdir : tempfile .TemporaryDirectory [str ] | None = None
348351 self ._precompile_args_path : str | None = None
@@ -406,6 +409,91 @@ def cleanup(self) -> None:
406409 self ._precompile_args_path = None
407410 self ._precompile_result_counter = count ()
408411
    # Fields excluded from pickle checkpoints: unpicklable infrastructure,
    # fields recomputed by _prepare(), and fields loaded separately.
    # Consumed by __getstate__ below when save_checkpoint() pickles self.
    _CHECKPOINT_EXCLUDE = frozenset(
        {
            # Unpicklable infrastructure (live kernel handle, loggers,
            # temp dirs, iterators)
            "kernel",
            "args",
            "log",
            "settings",
            "config_spec",
            "_precompile_tmpdir",
            "_precompile_args_path",
            "_precompile_result_counter",
            # Recomputed by _prepare() before checkpoint load
            "_baseline_output",
            "_baseline_post_args",
            "_mutated_arg_indices",
            "_effective_atol",
            "_effective_rtol",
            "_jobs",
            "_autotune_metrics",
            "_prepared",
            "_skip_cache",
            # Loaded separately via _load_crashed_configs()
            "_crashed_config_strs",
        }
    )
439+
440+ def __getstate__ (self ) -> dict [str , Any ]:
441+ return {
442+ k : v for k , v in self .__dict__ .items () if k not in self ._CHECKPOINT_EXCLUDE
443+ }
444+
445+ _stable_hash : str | None = None
446+
447+ def _get_stable_hash (self ) -> str :
448+ """Get the full stable hash for this kernel's cache key (cached)."""
449+ if self ._stable_hash is None :
450+ from .local_cache import LocalAutotuneCache
451+
452+ self ._stable_hash = LocalAutotuneCache (self )._generate_key ().stable_hash ()
453+ return self ._stable_hash
454+
455+ def _try_load_checkpoint (self ) -> bool :
456+ """Attempt to load checkpoint from checkpoint dir. Returns True if successful."""
457+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
458+ if checkpoint_dir_str is None :
459+ return False
460+
461+ checkpoint_dir = Path (checkpoint_dir_str )
462+ stable_hash = self ._get_stable_hash ()
463+ checkpoint_file = checkpoint_dir / f"{ stable_hash } .pt"
464+
465+ if not checkpoint_file .exists ():
466+ return False # No matching checkpoint; start fresh
467+
468+ # Matching file exists, attempt to load
469+ self .log (f"Resuming from checkpoint: { checkpoint_file } " )
470+ try :
471+ with open (checkpoint_file , "rb" ) as f :
472+ loaded = pickle .load (f )
473+ except Exception as e :
474+ raise exc .CheckpointError (
475+ f"Failed to load checkpoint file '{ checkpoint_file } ': { e } \n "
476+ f"The file may be corrupted. Delete it to start fresh."
477+ ) from e
478+
479+ # Validate stable hash matches (guards against renamed/copied files)
480+ loaded_hash = getattr (loaded , "_stable_hash" , None )
481+ if loaded_hash is not None and loaded_hash != self ._get_stable_hash ():
482+ raise exc .CheckpointError (
483+ "Checkpoint is incompatible: kernel, hardware, or input shapes "
484+ "may have changed."
485+ )
486+
487+ # Copy loaded search state into self (self already has kernel, args,
488+ # log, etc. from __init__ and _prepare())
489+ self .__dict__ .update (loaded .__dict__ )
490+
491+ self ._recompile_after_checkpoint ()
492+ self .log (f"Resumed at generation { self ._current_generation } " )
493+ return True
494+
    def _recompile_after_checkpoint(self) -> None:
        """Recompile after loading a checkpoint. Override in subclasses.

        Base implementation is a no-op. Called by _try_load_checkpoint()
        after the saved state has been copied into self, so subclasses can
        rebuild any compiled artifacts that were dropped during pickling.
        """
409497 def _compute_baseline (
410498 self ,
411499 ) -> tuple [object , Sequence [int ], Sequence [object ] | None ]:
@@ -1089,8 +1177,12 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
10891177 torch .save (self .args , args_path )
10901178 self ._precompile_args_path = args_path
10911179 exit_stack .callback (self .cleanup )
1180+
1181+ if not self ._try_load_checkpoint ():
1182+ self ._init_search ()
10921183 try :
10931184 best = self ._autotune ()
1185+ self ._cleanup_checkpoint ()
10941186 finally :
10951187 self ._finalize_autotune_metrics ()
10961188 end = time .perf_counter ()
@@ -1112,6 +1204,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11121204 print (triton_code , file = sys .stderr )
11131205 return best
11141206
    def _init_search(self) -> None:
        """
        Initialize the search state for a fresh autotuning run.

        This method is called when starting autotuning without a checkpoint
        (i.e. when _try_load_checkpoint() returns False). Subclasses should
        override this to set up the initial population and state. After this
        method, _current_generation should be set to the generation that
        _autotune() should start its loop from. Base implementation is a no-op.
        """
1216+
11151217 def _autotune (self ) -> Config :
11161218 """
11171219 Abstract method to perform the actual autotuning.
@@ -1123,6 +1225,68 @@ def _autotune(self) -> Config:
11231225 """
11241226 raise NotImplementedError
11251227
1228+ def save_checkpoint (self ) -> Path | None :
1229+ """
1230+ Save current autotuner state to checkpoint file.
1231+
1232+ Only saves when autotune_checkpoint_dir is set (opt-in).
1233+ Overwrites the same file each generation (keyed by stable hash).
1234+ Uses pickle to serialize the entire autotuner object (minus unpicklable
1235+ fields excluded by __getstate__).
1236+
1237+ Returns:
1238+ Path to saved checkpoint file, or None if not saved
1239+ """
1240+ from ..runtime .kernel import BoundKernel
1241+
1242+ # External kernels don't support caching/checkpointing
1243+ if not isinstance (self .kernel , BoundKernel ):
1244+ return None
1245+
1246+ if not self .kernel .is_cacheable ():
1247+ return None
1248+
1249+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
1250+ if checkpoint_dir_str is None :
1251+ return None # Opt-in: no dir set, no saving
1252+
1253+ stable_hash = self ._get_stable_hash ()
1254+ checkpoint_dir = Path (checkpoint_dir_str )
1255+ checkpoint_dir .mkdir (parents = True , exist_ok = True )
1256+ checkpoint_path = checkpoint_dir / f"{ stable_hash } .pt"
1257+
1258+ # Atomic write using temp file + rename
1259+ tmp = checkpoint_dir / f".tmp.{ stable_hash } .{ os .getpid ()} "
1260+ with open (tmp , "wb" ) as f :
1261+ pickle .dump (self , f )
1262+ os .replace (tmp , checkpoint_path )
1263+
1264+ self .log (f"Checkpoint saved: { checkpoint_path } " )
1265+ return checkpoint_path
1266+
1267+ def _cleanup_checkpoint (self ) -> None :
1268+ """Delete checkpoint file on successful autotune completion.
1269+
1270+ Checkpoints are ephemeral in-progress state. Once autotuning
1271+ completes successfully, the result is cached normally and the
1272+ checkpoint is no longer needed.
1273+ """
1274+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
1275+ if checkpoint_dir_str is None :
1276+ return
1277+
1278+ stable_hash = self ._get_stable_hash ()
1279+ checkpoint_file = Path (checkpoint_dir_str ) / f"{ stable_hash } .pt"
1280+ if checkpoint_file .exists ():
1281+ checkpoint_file .unlink ()
1282+ self .log (f"Checkpoint cleaned up: { checkpoint_file } " )
1283+
1284+ # Clean up crash-recovery artifacts
1285+ for suffix in (".pending_config" , ".crashed_configs" ):
1286+ artifact = Path (checkpoint_dir_str ) / f"{ stable_hash } { suffix } "
1287+ if artifact .exists ():
1288+ artifact .unlink ()
1289+
    def set_generation(self, generation: int) -> None:
        """Record the current search generation in the autotune metrics."""
        self._autotune_metrics.num_generations = generation
11281292
@@ -1177,6 +1341,15 @@ class PopulationMember:
11771341 def perf (self ) -> float :
11781342 return self .perfs [- 1 ]
11791343
1344+ def __getstate__ (self ) -> dict [str , Any ]:
1345+ state = self .__dict__ .copy ()
1346+ state ["fn" ] = None # compiled functions are not picklable
1347+ return state
1348+
1349+ def __setstate__ (self , state : dict [str , Any ]) -> None :
1350+ self .__dict__ .update (state )
1351+ self .fn = _unset_fn
1352+
11801353
11811354def performance (member : PopulationMember ) -> float :
11821355 """
@@ -1570,6 +1743,14 @@ def rebenchmark_population(
15701743 members = self .population
15711744 self .rebenchmark ([p for p in members if self .should_rebenchmark (p )], desc = desc )
15721745
1746+ def set_generation (self , generation : int ) -> None :
1747+ if generation == self ._current_generation :
1748+ return
1749+ self ._current_generation = generation
1750+ super ().set_generation (generation )
1751+ if generation > 0 :
1752+ self .save_checkpoint ()
1753+
15731754 def statistics (self ) -> str :
15741755 """
15751756 Generate statistics for the current population.
@@ -1579,6 +1760,27 @@ def statistics(self) -> str:
15791760 """
15801761 return population_statistics (self .population )
15811762
1763+ def _recompile_after_checkpoint (self ) -> None :
1764+ """Recompile kernel functions for population members after checkpoint load."""
1765+ recompile_failures : list [tuple [PopulationMember , str ]] = []
1766+ for member in self .population :
1767+ if member .fn is _unset_fn and member .status == "ok" :
1768+ try :
1769+ member .fn = self .kernel .compile_config (
1770+ member .config , allow_print = False
1771+ )
1772+ except Exception as e :
1773+ member .fn = _unset_fn
1774+ member .status = "error"
1775+ member .perfs .append (inf ) # Ensure member won't be selected as best
1776+ recompile_failures .append ((member , str (e )))
1777+
1778+ if recompile_failures :
1779+ self .log (
1780+ f"Warning: { len (recompile_failures )} config(s) failed to recompile "
1781+ f"and will be skipped. First failure: { recompile_failures [0 ][1 ]} "
1782+ )
1783+
15821784 def run_finishing_phase (
15831785 self , best : PopulationMember , rounds : int
15841786 ) -> PopulationMember :
0 commit comments