1212import math
1313from math import inf
1414import os
15+ from pathlib import Path
16+ import pickle
1517import pprint
1618import random
1719import re
@@ -343,6 +345,8 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
343345 self .args : Sequence [object ] = args
344346 self .log = AutotuningLogger (self .settings )
345347 self .best_perf_so_far = inf
348+ self ._current_generation = 0
349+ self .counters : collections .Counter [str ] = collections .Counter ()
346350 self ._prepared = False
347351 self ._precompile_tmpdir : tempfile .TemporaryDirectory [str ] | None = None
348352 self ._precompile_args_path : str | None = None
@@ -406,6 +410,90 @@ def cleanup(self) -> None:
406410 self ._precompile_args_path = None
407411 self ._precompile_result_counter = count ()
408412
413+ # Fields excluded from pickle checkpoints: unpicklable infrastructure,
414+ # fields recomputed by _prepare(), and fields loaded separately.
415+ _CHECKPOINT_EXCLUDE = frozenset (
416+ {
417+ # Unpicklable infrastructure
418+ "kernel" ,
419+ "args" ,
420+ "log" ,
421+ "settings" ,
422+ "config_spec" ,
423+ "_precompile_tmpdir" ,
424+ "_precompile_args_path" ,
425+ "_precompile_result_counter" ,
426+ # Recomputed by _prepare() before checkpoint load
427+ "_baseline_output" ,
428+ "_baseline_post_args" ,
429+ "_mutated_arg_indices" ,
430+ "_effective_atol" ,
431+ "_effective_rtol" ,
432+ "_jobs" ,
433+ "_autotune_metrics" ,
434+ "_prepared" ,
435+ "_skip_cache" ,
436+ # Loaded separately via _load_crashed_configs()
437+ "_crashed_config_strs" ,
438+ }
439+ )
440+
441+ def __getstate__ (self ) -> dict [str , Any ]:
442+ return {
443+ k : v for k , v in self .__dict__ .items () if k not in self ._CHECKPOINT_EXCLUDE
444+ }
445+
446+ _stable_hash : str | None = None
447+
448+ def _get_stable_hash (self ) -> str :
449+ """Get the full stable hash for this kernel's cache key (cached)."""
450+ if self ._stable_hash is None :
451+ from .local_cache import LocalAutotuneCache
452+
453+ self ._stable_hash = LocalAutotuneCache (self )._generate_key ().stable_hash ()
454+ return self ._stable_hash
455+
456+ def _try_load_checkpoint (self ) -> bool :
457+ """Attempt to load checkpoint from checkpoint dir. Returns True if successful."""
458+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
459+ if checkpoint_dir_str is None :
460+ return False
461+
462+ checkpoint_dir = Path (checkpoint_dir_str )
463+ stable_hash = self ._get_stable_hash ()
464+ checkpoint_file = checkpoint_dir / f"{ stable_hash } .pt"
465+
466+ if not checkpoint_file .exists ():
467+ return False # No matching checkpoint; start fresh
468+
469+ # Matching file exists, attempt to load
470+ self .log (f"Resuming from checkpoint: { checkpoint_file } " )
471+ try :
472+ with open (checkpoint_file , "rb" ) as f :
473+ loaded = pickle .load (f )
474+ except Exception as e :
475+ raise exc .CheckpointError (
476+ f"Failed to load checkpoint file '{ checkpoint_file } ': { e } \n "
477+ f"The file may be corrupted. Delete it to start fresh."
478+ ) from e
479+
480+ # Validate stable hash matches (guards against renamed/copied files)
481+ loaded_hash = getattr (loaded , "_stable_hash" , None )
482+ if loaded_hash is not None and loaded_hash != self ._get_stable_hash ():
483+ raise exc .CheckpointError (
484+ "Checkpoint is incompatible: kernel, hardware, or input shapes "
485+ "may have changed."
486+ )
487+
488+ # Copy loaded search state into self (self already has kernel, args,
489+ # log, etc. from __init__ and _prepare())
490+ self .__dict__ .update (loaded .__dict__ )
491+ self ._recompile_after_checkpoint ()
492+ self .log (f"Resumed at generation { self ._current_generation } " )
493+ return True
494+
495+ def _recompile_after_checkpoint (self ) -> None :
496+ """Recompile after loading a checkpoint. Override in subclasses."""
409497 def _compute_baseline (
410498 self ,
411499 ) -> tuple [object , Sequence [int ], Sequence [object ] | None ]:
@@ -629,6 +717,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
629717 The performance of the configuration in ms.
630718 """
631719 self ._autotune_metrics .num_configs_tested += 1
720+ self .counters ["benchmark" ] += 1
632721 self .log .debug (lambda : f"Running benchmark for { config !r} " )
633722 _captured_output : list [str ] = ["" ]
634723 _capture_ctx = (
@@ -1089,8 +1178,12 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
10891178 torch .save (self .args , args_path )
10901179 self ._precompile_args_path = args_path
10911180 exit_stack .callback (self .cleanup )
1181+
1182+ if not self ._try_load_checkpoint ():
1183+ self ._init_search ()
10921184 try :
10931185 best = self ._autotune ()
1186+ self ._cleanup_checkpoint ()
10941187 finally :
10951188 self ._finalize_autotune_metrics ()
10961189 end = time .perf_counter ()
@@ -1112,6 +1205,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11121205 print (triton_code , file = sys .stderr )
11131206 return best
11141207
1208+ def _init_search (self ) -> None :
1209+ """
1210+ Initialize the search state for a fresh autotuning run.
1211+
1212+ This method is called when starting autotuning without a checkpoint.
1213+ Subclasses should override this to set up initial population and state.
1214+ After this method, _current_generation should be set to the generation
1215+ that _autotune() should start its loop from.
1216+ """
1217+
11151218 def _autotune (self ) -> Config :
11161219 """
11171220 Abstract method to perform the actual autotuning.
@@ -1123,6 +1226,68 @@ def _autotune(self) -> Config:
11231226 """
11241227 raise NotImplementedError
11251228
1229+ def save_checkpoint (self ) -> Path | None :
1230+ """
1231+ Save current autotuner state to checkpoint file.
1232+
1233+ Only saves when autotune_checkpoint_dir is set (opt-in).
1234+ Overwrites the same file each generation (keyed by stable hash).
1235+ Uses pickle to serialize the entire autotuner object (minus unpicklable
1236+ fields excluded by __getstate__).
1237+
1238+ Returns:
1239+ Path to saved checkpoint file, or None if not saved
1240+ """
1241+ from ..runtime .kernel import BoundKernel
1242+
1243+ # External kernels don't support caching/checkpointing
1244+ if not isinstance (self .kernel , BoundKernel ):
1245+ return None
1246+
1247+ if not self .kernel .is_cacheable ():
1248+ return None
1249+
1250+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
1251+ if checkpoint_dir_str is None :
1252+ return None # Opt-in: no dir set, no saving
1253+
1254+ stable_hash = self ._get_stable_hash ()
1255+ checkpoint_dir = Path (checkpoint_dir_str )
1256+ checkpoint_dir .mkdir (parents = True , exist_ok = True )
1257+ checkpoint_path = checkpoint_dir / f"{ stable_hash } .pt"
1258+
1259+ # Atomic write using temp file + rename
1260+ tmp = checkpoint_dir / f".tmp.{ stable_hash } .{ os .getpid ()} "
1261+ with open (tmp , "wb" ) as f :
1262+ pickle .dump (self , f )
1263+ os .replace (tmp , checkpoint_path )
1264+
1265+ self .log (f"Checkpoint saved: { checkpoint_path } " )
1266+ return checkpoint_path
1267+
1268+ def _cleanup_checkpoint (self ) -> None :
1269+ """Delete checkpoint file on successful autotune completion.
1270+
1271+ Checkpoints are ephemeral in-progress state. Once autotuning
1272+ completes successfully, the result is cached normally and the
1273+ checkpoint is no longer needed.
1274+ """
1275+ checkpoint_dir_str = self .settings .autotune_checkpoint_dir
1276+ if checkpoint_dir_str is None :
1277+ return
1278+
1279+ stable_hash = self ._get_stable_hash ()
1280+ checkpoint_file = Path (checkpoint_dir_str ) / f"{ stable_hash } .pt"
1281+ if checkpoint_file .exists ():
1282+ checkpoint_file .unlink ()
1283+ self .log (f"Checkpoint cleaned up: { checkpoint_file } " )
1284+
1285+ # Clean up crash-recovery artifacts
1286+ for suffix in (".pending_config" , ".crashed_configs" ):
1287+ artifact = Path (checkpoint_dir_str ) / f"{ stable_hash } { suffix } "
1288+ if artifact .exists ():
1289+ artifact .unlink ()
1290+
11261291 def set_generation (self , generation : int ) -> None :
11271292 self ._autotune_metrics .num_generations = generation
11281293
@@ -1177,6 +1342,15 @@ class PopulationMember:
11771342 def perf (self ) -> float :
11781343 return self .perfs [- 1 ]
11791344
1345+ def __getstate__ (self ) -> dict [str , Any ]:
1346+ state = self .__dict__ .copy ()
1347+ state ["fn" ] = None # compiled functions are not picklable
1348+ return state
1349+
1350+ def __setstate__ (self , state : dict [str , Any ]) -> None :
1351+ self .__dict__ .update (state )
1352+ self .fn = _unset_fn
1353+
11801354
11811355def performance (member : PopulationMember ) -> float :
11821356 """
@@ -1570,6 +1744,14 @@ def rebenchmark_population(
15701744 members = self .population
15711745 self .rebenchmark ([p for p in members if self .should_rebenchmark (p )], desc = desc )
15721746
1747+ def set_generation (self , generation : int ) -> None :
1748+ if generation == self ._current_generation :
1749+ return
1750+ self ._current_generation = generation
1751+ super ().set_generation (generation )
1752+ if generation > 0 :
1753+ self .save_checkpoint ()
1754+
15731755 def statistics (self ) -> str :
15741756 """
15751757 Generate statistics for the current population.
@@ -1579,6 +1761,27 @@ def statistics(self) -> str:
15791761 """
15801762 return population_statistics (self .population )
15811763
1764+ def _recompile_after_checkpoint (self ) -> None :
1765+ """Recompile kernel functions for population members after checkpoint load."""
1766+ recompile_failures : list [tuple [PopulationMember , str ]] = []
1767+ for member in self .population :
1768+ if member .fn is _unset_fn and member .status == "ok" :
1769+ try :
1770+ member .fn = self .kernel .compile_config (
1771+ member .config , allow_print = False
1772+ )
1773+ except Exception as e :
1774+ member .fn = _unset_fn
1775+ member .status = "error"
1776+ member .perfs .append (inf ) # Ensure member won't be selected as best
1777+ recompile_failures .append ((member , str (e )))
1778+
1779+ if recompile_failures :
1780+ self .log (
1781+ f"Warning: { len (recompile_failures )} config(s) failed to recompile "
1782+ f"and will be skipped. First failure: { recompile_failures [0 ][1 ]} "
1783+ )
1784+
15821785 def run_finishing_phase (
15831786 self , best : PopulationMember , rounds : int
15841787 ) -> PopulationMember :
0 commit comments