1212import math
1313from math import inf
1414import os
15+ from pathlib import Path
16+ import pickle
1517import pprint
1618import random
1719import re
@@ -343,6 +345,15 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
343345 self .args : Sequence [object ] = args
344346 self .log = AutotuningLogger (self .settings )
345347 self .best_perf_so_far = inf
348+ self ._current_generation = 0
349+ self .counters : collections .Counter [str ] = collections .Counter ()
350+ self ._autotune_metrics : AutotuneMetrics = AutotuneMetrics (
351+ kernel_name = "" ,
352+ input_shapes = "" ,
353+ hardware = "" ,
354+ random_seed = 0 ,
355+ search_algorithm = type (self ).__name__ ,
356+ )
346357 self ._prepared = False
347358 self ._precompile_tmpdir : tempfile .TemporaryDirectory [str ] | None = None
348359 self ._precompile_args_path : str | None = None
@@ -406,6 +417,90 @@ def cleanup(self) -> None:
406417 self ._precompile_args_path = None
407418 self ._precompile_result_counter = count ()
408419
# Names of instance attributes that __getstate__ leaves out of pickle
# checkpoints. Three groups: infrastructure that cannot be pickled, values
# recomputed by _prepare(), and state that is loaded through a separate path.
_CHECKPOINT_EXCLUDE = frozenset(
    (
        # -- infrastructure that cannot be pickled --
        "kernel",
        "args",
        "log",
        "settings",
        "config_spec",
        "_precompile_tmpdir",
        "_precompile_args_path",
        "_precompile_result_counter",
        # -- recomputed by _prepare() before the checkpoint is loaded --
        "_baseline_output",
        "_baseline_post_args",
        "_mutated_arg_indices",
        "_effective_atol",
        "_effective_rtol",
        "_jobs",
        "_autotune_metrics",
        "_prepared",
        "_skip_cache",
        # -- loaded separately via _load_crashed_configs() --
        "_crashed_config_strs",
    )
)
447+
def __getstate__(self) -> dict[str, Any]:
    """Return the picklable subset of this object's attributes.

    Every entry of the instance ``__dict__`` is kept except those whose
    name appears in ``_CHECKPOINT_EXCLUDE``.
    """
    excluded = self._CHECKPOINT_EXCLUDE
    state: dict[str, Any] = {}
    for name, value in self.__dict__.items():
        if name in excluded:
            continue
        state[name] = value
    return state
452+
453+ _stable_hash : str | None = None
454+
455+ def _get_stable_hash (self ) -> str :
456+ """Get the full stable hash for this kernel's cache key (cached)."""
457+ if self ._stable_hash is None :
458+ from .local_cache import LocalAutotuneCache
459+
460+ self ._stable_hash = LocalAutotuneCache (self )._generate_key ().stable_hash ()
461+ return self ._stable_hash
462+
def _try_load_checkpoint(self) -> bool:
    """Try to resume search state from ``<checkpoint_dir>/<stable_hash>.pt``.

    Returns:
        True if a checkpoint was found and its state was copied into
        ``self``; False if checkpointing is disabled
        (``settings.autotune_checkpoint_dir`` is None) or no file exists for
        this kernel's stable hash.

    Raises:
        exc.CheckpointError: If the file cannot be unpickled, does not
            contain an instance of this search class, or records a
            different stable hash than the current kernel.
    """
    checkpoint_dir_str = self.settings.autotune_checkpoint_dir
    if checkpoint_dir_str is None:
        return False

    stable_hash = self._get_stable_hash()
    checkpoint_file = Path(checkpoint_dir_str) / f"{stable_hash}.pt"

    if not checkpoint_file.exists():
        return False  # No matching checkpoint; start fresh

    # Matching file exists, attempt to load
    self.log(f"Resuming from checkpoint: {checkpoint_file}")
    try:
        with open(checkpoint_file, "rb") as f:
            # NOTE(security): unpickling can execute arbitrary code. The
            # checkpoint directory must only contain files written by
            # save_checkpoint() from a trusted run.
            loaded = pickle.load(f)
    except Exception as e:
        raise exc.CheckpointError(
            f"Failed to load checkpoint file '{checkpoint_file}': {e}\n"
            f"The file may be corrupted. Delete it to start fresh."
        ) from e

    # The file name is keyed only by the stable hash, so the payload could
    # be something other than this search class (different algorithm, or a
    # foreign pickle). Merging such state into self would fail later in a
    # confusing way, so reject it here with a checkpoint-specific error.
    if not isinstance(loaded, type(self)):
        raise exc.CheckpointError(
            f"Checkpoint file '{checkpoint_file}' is incompatible: it contains "
            f"a {type(loaded).__name__}, expected {type(self).__name__}. "
            f"Delete it to start fresh."
        )

    # Validate stable hash matches (guards against renamed/copied files)
    loaded_hash = getattr(loaded, "_stable_hash", None)
    if loaded_hash is not None and loaded_hash != stable_hash:
        raise exc.CheckpointError(
            "Checkpoint is incompatible: kernel, hardware, or input shapes "
            "may have changed."
        )

    # Copy loaded search state into self. Attributes left out by
    # __getstate__ are absent from the loaded object, so the values self
    # already holds for them are kept.
    self.__dict__.update(loaded.__dict__)
    self._recompile_after_checkpoint()
    self.log(f"Resumed at generation {self._current_generation}")
    return True
501+
def _recompile_after_checkpoint(self) -> None:
    """Hook invoked by _try_load_checkpoint() after state has been restored.

    The base implementation does nothing; subclasses override it to
    rebuild whatever compiled state they need.
    """
    return None
409504 def _compute_baseline (
410505 self ,
411506 ) -> tuple [object , Sequence [int ], Sequence [object ] | None ]:
@@ -629,6 +724,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
629724 The performance of the configuration in ms.
630725 """
631726 self ._autotune_metrics .num_configs_tested += 1
727+ self .counters ["benchmark" ] += 1
632728 self .log .debug (lambda : f"Running benchmark for { config !r} " )
633729 _captured_output : list [str ] = ["" ]
634730 _capture_ctx = (
@@ -1089,8 +1185,12 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
10891185 torch .save (self .args , args_path )
10901186 self ._precompile_args_path = args_path
10911187 exit_stack .callback (self .cleanup )
1188+
1189+ if not self ._try_load_checkpoint ():
1190+ self ._init_search ()
10921191 try :
10931192 best = self ._autotune ()
1193+ self ._cleanup_checkpoint ()
10941194 finally :
10951195 self ._finalize_autotune_metrics ()
10961196 end = time .perf_counter ()
@@ -1112,6 +1212,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11121212 print (triton_code , file = sys .stderr )
11131213 return best
11141214
def _init_search(self) -> None:
    """
    Set up search state for a run that starts without a checkpoint.

    autotune() calls this hook only when _try_load_checkpoint() returns
    False. The base implementation is a no-op. Subclasses should override
    it to build their initial population and state, leaving
    _current_generation at the generation from which _autotune() should
    start its loop.
    """
    return None
1224+
11151225 def _autotune (self ) -> Config :
11161226 """
11171227 Abstract method to perform the actual autotuning.
@@ -1123,6 +1233,68 @@ def _autotune(self) -> Config:
11231233 """
11241234 raise NotImplementedError
11251235
def save_checkpoint(self) -> Path | None:
    """
    Save current autotuner state to checkpoint file.

    Only saves when autotune_checkpoint_dir is set (opt-in).
    Overwrites the same file on every call (keyed by stable hash).
    Uses pickle to serialize the entire autotuner object (minus the
    fields excluded by __getstate__).

    The data is written to a temporary file in the checkpoint directory
    and then moved over the final path with os.replace(), so an existing
    checkpoint is only replaced once the new one is fully written. If
    writing or replacing fails, the temporary file is removed and the
    original exception propagates.

    Returns:
        Path to saved checkpoint file, or None if not saved
    """
    from ..runtime.kernel import BoundKernel

    # External kernels don't support caching/checkpointing
    if not isinstance(self.kernel, BoundKernel):
        return None

    if not self.kernel.is_cacheable():
        return None

    checkpoint_dir_str = self.settings.autotune_checkpoint_dir
    if checkpoint_dir_str is None:
        return None  # Opt-in: no dir set, no saving

    stable_hash = self._get_stable_hash()
    checkpoint_dir = Path(checkpoint_dir_str)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = checkpoint_dir / f"{stable_hash}.pt"

    # Temp file + rename; the pid keeps concurrent processes from sharing
    # one temporary file name.
    tmp = checkpoint_dir / f".tmp.{stable_hash}.{os.getpid()}"
    try:
        with open(tmp, "wb") as f:
            pickle.dump(self, f)
        os.replace(tmp, checkpoint_path)
    except BaseException:
        # Don't leave a partial temp file behind when serialization or
        # the rename fails (including on KeyboardInterrupt).
        tmp.unlink(missing_ok=True)
        raise

    self.log(f"Checkpoint saved: {checkpoint_path}")
    return checkpoint_path
1274+
def _cleanup_checkpoint(self) -> None:
    """Delete checkpoint file on successful autotune completion.

    Checkpoints are ephemeral in-progress state. Once autotuning
    completes successfully, the result is cached normally and the
    checkpoint is no longer needed.

    Does nothing when ``settings.autotune_checkpoint_dir`` is None. Files
    that are already gone (for example removed by another process between
    runs) are ignored rather than raising FileNotFoundError.
    """
    checkpoint_dir_str = self.settings.autotune_checkpoint_dir
    if checkpoint_dir_str is None:
        return

    checkpoint_dir = Path(checkpoint_dir_str)
    stable_hash = self._get_stable_hash()

    checkpoint_file = checkpoint_dir / f"{stable_hash}.pt"
    try:
        # Unlink directly instead of exists()+unlink(): the file may
        # disappear between the check and the delete.
        checkpoint_file.unlink()
    except FileNotFoundError:
        pass
    else:
        self.log(f"Checkpoint cleaned up: {checkpoint_file}")

    # Clean up crash-recovery artifacts
    for suffix in (".pending_config", ".crashed_configs"):
        artifact = checkpoint_dir / f"{stable_hash}{suffix}"
        artifact.unlink(missing_ok=True)
1297+
def set_generation(self, generation: int) -> None:
    """Record ``generation`` as the generation count in the autotune metrics."""
    metrics = self._autotune_metrics
    metrics.num_generations = generation
11281300
@@ -1177,6 +1349,15 @@ class PopulationMember:
11771349 def perf (self ) -> float :
11781350 return self .perfs [- 1 ]
11791351
def __getstate__(self) -> dict[str, Any]:
    """Return a copy of the instance attributes with ``fn`` set to None.

    Compiled functions are not picklable, so the copy carries None in
    their place. The live object is left untouched.
    """
    return {**self.__dict__, "fn": None}
1356+
def __setstate__(self, state: dict[str, Any]) -> None:
    """Restore attributes from ``state`` and mark ``fn`` as not compiled.

    Whatever ``state`` holds under ``fn`` is replaced by the ``_unset_fn``
    placeholder.
    """
    vars(self).update(state)
    self.fn = _unset_fn
1360+
11801361
11811362def performance (member : PopulationMember ) -> float :
11821363 """
@@ -1570,6 +1751,14 @@ def rebenchmark_population(
15701751 members = self .population
15711752 self .rebenchmark ([p for p in members if self .should_rebenchmark (p )], desc = desc )
15721753
def set_generation(self, generation: int) -> None:
    """Advance to ``generation`` and checkpoint the new state.

    A call with the generation that is already current is a no-op.
    Otherwise the new generation is stored, forwarded to the parent
    implementation, and, for any generation above 0, save_checkpoint()
    is called.
    """
    if generation != self._current_generation:
        self._current_generation = generation
        super().set_generation(generation)
        if generation > 0:
            self.save_checkpoint()
1761+
15731762 def statistics (self ) -> str :
15741763 """
15751764 Generate statistics for the current population.
@@ -1579,6 +1768,27 @@ def statistics(self) -> str:
15791768 """
15801769 return population_statistics (self .population )
15811770
def _recompile_after_checkpoint(self) -> None:
    """Rebuild compiled functions for population members after a resume.

    Only members whose ``fn`` is the ``_unset_fn`` placeholder and whose
    status is "ok" are compiled. A member that fails to compile is marked
    "error" and gets ``inf`` appended to its perfs. A single summary line
    with the failure count and the first error message is logged if
    anything failed.
    """
    failure_messages: list[str] = []
    for member in self.population:
        needs_compile = member.fn is _unset_fn and member.status == "ok"
        if not needs_compile:
            continue
        try:
            member.fn = self.kernel.compile_config(member.config, allow_print=False)
        except Exception as err:
            member.fn = _unset_fn
            member.status = "error"
            # inf keeps this member from being selected as best
            member.perfs.append(inf)
            failure_messages.append(str(err))

    if not failure_messages:
        return
    self.log(
        f"Warning: {len(failure_messages)} config(s) failed to recompile "
        f"and will be skipped. First failure: {failure_messages[0]}"
    )
1791+
15821792 def run_finishing_phase (
15831793 self , best : PopulationMember , rounds : int
15841794 ) -> PopulationMember :
0 commit comments