Commit 9957ed4
[Auto-Recovery] Add checkpoint save/load/resume
Add opt-in checkpoint support gated behind HELION_AUTOTUNE_CHECKPOINT_DIR. When set, the autotuner saves in-progress state each generation and can resume from a checkpoint on subsequent runs. The checkpoint file is deleted on successful completion.

Includes pickle serialization support for BaseSearch and PopulationMember, stable-hash-based checkpoint file naming, atomic writes, and kernel recompilation on checkpoint load.

stack-info: PR: #1947, branch: yf225/stack/96
1 parent 4f24dc9 commit 9957ed4

File tree

5 files changed: +263 -1 lines changed

docs/api/settings.md

Lines changed: 9 additions & 0 deletions
@@ -209,6 +209,14 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
     Each preset also sets a default initial population strategy (see :doc:`../deployment_autotuning` for details).
     Users can still override individual ``autotune_*`` settings; explicit values win over the preset. Controlled by ``HELION_AUTOTUNE_EFFORT``.
+
+.. autoattribute:: Settings.autotune_checkpoint_dir
+
+    Directory path for saving and resuming autotuning checkpoints. When set, the autotuner
+    saves in-progress state to ``{dir}/{stable_hash}.pt`` and auto-discovers matching
+    checkpoints on subsequent runs. The checkpoint file is deleted on successful completion.
+    When unset (default), no checkpoints are saved or loaded (opt-in).
+    Controlled by ``HELION_AUTOTUNE_CHECKPOINT_DIR``.
+
 .. autoattribute:: Settings.autotune_best_available_max_configs
 
     Maximum number of cached configs to use when seeding the initial population with the ``from_best_available`` strategy.
@@ -323,6 +331,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"LFBOTreeSearch"`` (default),
 | ``HELION_AUTOTUNE_PROGRESS_BAR`` | ``autotune_progress_bar`` | Enable or disable the progress bar UI during autotuning. |
 | ``HELION_AUTOTUNE_IGNORE_ERRORS`` | ``autotune_ignore_errors`` | Continue autotuning even when recoverable runtime errors occur. |
 | ``HELION_AUTOTUNE_CONFIG_OVERRIDES`` | ``autotune_config_overrides`` | Supply JSON forcing particular autotuner config key/value pairs. |
+| ``HELION_AUTOTUNE_CHECKPOINT_DIR`` | ``autotune_checkpoint_dir`` | Directory path for saving/resuming autotuning checkpoints (opt-in). |
 | ``TRITON_STORE_BINARY_ONLY`` | Triton (autotuning) | Set to ``1`` during autotuning to skip Triton intermediate IRs, reducing cache size ~40%. Set to ``0`` to retain IRs for debugging. |
 | ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
 | ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, skip both reading and writing the autotuning cache entirely. |

docs/deployment_autotuning.md

Lines changed: 23 additions & 0 deletions
@@ -183,6 +183,29 @@ Related settings for `from_best_available` (see {doc}`api/settings`):
 | `autotune_best_available_max_configs` | `HELION_BEST_AVAILABLE_MAX_CONFIGS` | 20 | Maximum cached configs to seed |
 | `autotune_best_available_max_cache_scan` | `HELION_BEST_AVAILABLE_MAX_CACHE_SCAN` | 500 | Maximum cache files to scan |
 
+### Checkpointing Long-Running Autotuning
+
+For very long autotuning sessions, you can save and resume state using
+checkpoints. This is useful when tuning might be interrupted (e.g., preemptible
+instances) or when you want to continue tuning from a previous unfinished run.
+
+Set the `HELION_AUTOTUNE_CHECKPOINT_DIR` environment variable to a directory
+path. The autotuner will periodically save checkpoints there, keyed by the
+kernel's stable hash. If interrupted, re-run with the same directory to resume
+automatically. On successful completion, the checkpoint file is cleaned up.
+
+```bash
+# Enable checkpointing to a directory:
+HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
+
+# If interrupted, just re-run with the same directory to resume:
+HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
+```
+
+Without `HELION_AUTOTUNE_CHECKPOINT_DIR`, no checkpoints are saved (opt-in).
+Multiple kernels can safely use the same directory — each kernel writes to a
+file named by its unique stable hash.
+
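The resume flow described above can be sketched with the standard library alone. The `{dir}/{stable_hash}.pt` layout follows the docs; `try_resume` and its arguments are illustrative names, not Helion's actual API:

```python
import pickle
from pathlib import Path


def try_resume(checkpoint_dir, stable_hash):
    """Return previously saved state if a matching checkpoint exists, else None."""
    if checkpoint_dir is None:
        return None  # opt-in: no directory configured means no checkpointing
    path = Path(checkpoint_dir) / f"{stable_hash}.pt"
    if not path.exists():
        return None  # no matching checkpoint; start fresh
    with open(path, "rb") as f:
        return pickle.load(f)
```

Because the file is keyed by the stable hash, several kernels can share one directory without colliding.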
 ## Deploy a Single Config
 
 If one configuration wins for every production call, bake it into the decorator:

helion/autotuner/base_search.py

Lines changed: 213 additions & 1 deletion
@@ -12,6 +12,8 @@
 import math
 from math import inf
 import os
+from pathlib import Path
+import pickle
 import pprint
 import random
 import re
@@ -407,6 +409,114 @@ def cleanup(self) -> None:
         self._precompile_args_path = None
         self._precompile_result_counter = count()
 
+    # Fields excluded from pickle checkpoints: unpicklable infrastructure,
+    # fields recomputed by _prepare(), and fields loaded separately.
+    _CHECKPOINT_EXCLUDE = frozenset(
+        {
+            # Unpicklable infrastructure
+            "kernel",
+            "args",
+            "log",
+            "settings",
+            "config_spec",
+            "_precompile_tmpdir",
+            "_precompile_args_path",
+            "_precompile_result_counter",
+            # Recomputed by _prepare() before checkpoint load
+            "_baseline_output",
+            "_baseline_post_args",
+            "_mutated_arg_indices",
+            "_effective_atol",
+            "_effective_rtol",
+            "_jobs",
+            "_autotune_metrics",
+            "_prepared",
+            "_skip_cache",
+            # Loaded separately via _load_crashed_configs()
+            "_crashed_config_strs",
+        }
+    )
+
+    def __getstate__(self) -> dict[str, Any]:
+        return {
+            k: v for k, v in self.__dict__.items() if k not in self._CHECKPOINT_EXCLUDE
+        }
+
+    _stable_hash: str | None = None
+
+    def _get_stable_hash(self) -> str:
+        """Get the full stable hash for this kernel's cache key (cached)."""
+        if self._stable_hash is None:
+            from .local_cache import LocalAutotuneCache
+
+            self._stable_hash = LocalAutotuneCache(self)._generate_key().stable_hash()
+        return self._stable_hash
+
+    def _try_load_checkpoint(self) -> bool:
+        """Attempt to load checkpoint from checkpoint dir. Returns True if successful."""
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return False
+
+        checkpoint_dir = Path(checkpoint_dir_str)
+        stable_hash = self._get_stable_hash()
+        checkpoint_file = checkpoint_dir / f"{stable_hash}.pt"
+
+        if not checkpoint_file.exists():
+            return False  # No matching checkpoint; start fresh
+
+        # Matching file exists, attempt to load
+        self.log(f"Resuming from checkpoint: {checkpoint_file}")
+        try:
+            with open(checkpoint_file, "rb") as f:
+                loaded = pickle.load(f)
+        except Exception as e:
+            raise exc.CheckpointError(
+                f"Failed to load checkpoint file '{checkpoint_file}': {e}\n"
+                f"The file may be corrupted. Delete it to start fresh."
+            ) from e
+
+        # Validate stable hash matches (guards against renamed/copied files)
+        loaded_hash = getattr(loaded, "_stable_hash", None)
+        if loaded_hash is not None and loaded_hash != self._get_stable_hash():
+            raise exc.CheckpointError(
+                "Checkpoint is incompatible: kernel, hardware, or input shapes "
+                "may have changed."
+            )
+
+        # Copy loaded search state into self (self already has kernel, args,
+        # log, etc. from __init__ and _prepare())
+        self.__dict__.update(loaded.__dict__)
+
+        self.log(f"Resumed at generation {self._current_generation}")
+        return True
+
+    def _load_crashed_configs(self) -> None:
+        """Load crashed configs from {hash}.crashed_configs (written by crash-recovery script)."""
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return
+        crashed_configs_path = (
+            Path(checkpoint_dir_str) / f"{self._get_stable_hash()}.crashed_configs"
+        )
+        if crashed_configs_path.exists():
+            self._crashed_config_strs |= {
+                line.strip()
+                for line in crashed_configs_path.read_text().splitlines()
+                if line.strip()
+            }
+            if self._crashed_config_strs:
+                self.log(
+                    f"Loaded {len(self._crashed_config_strs)} crashed config(s) to skip"
+                )
+
+    def _get_pending_config_path(self) -> Path | None:
+        """Get path for pending-config sentinel, or None if checkpointing disabled."""
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return None
+        return Path(checkpoint_dir_str) / f"{self._get_stable_hash()}.pending_config"
+
     def _compute_baseline(
         self,
     ) -> tuple[object, Sequence[int], Sequence[object] | None]:
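The hash-validation guard in `_try_load_checkpoint` follows a common pattern: the identity hash is stored inside the checkpoint itself, so a renamed or copied file is rejected on load. A minimal stdlib sketch of that guard (class and function names here are illustrative, not Helion's internals):

```python
import pickle


class SearchState:
    """Stand-in for the autotuner's picklable search state (illustrative)."""

    def __init__(self, stable_hash, generation=0):
        self._stable_hash = stable_hash
        self.generation = generation


def load_validated(path, expected_hash):
    """Unpickle a checkpoint and reject it if its embedded hash mismatches."""
    with open(path, "rb") as f:
        loaded = pickle.load(f)
    # The hash stored inside the pickle must match the hash derived from the
    # current kernel/inputs; otherwise the file was renamed or copied here.
    loaded_hash = getattr(loaded, "_stable_hash", None)
    if loaded_hash is not None and loaded_hash != expected_hash:
        raise ValueError("checkpoint is incompatible with this kernel")
    return loaded
```

The `getattr(..., None)` default keeps older checkpoints without an embedded hash loadable, mirroring the `loaded_hash is not None` check in the diff.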
@@ -1091,9 +1201,13 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
         self._precompile_args_path = args_path
         exit_stack.callback(self.cleanup)
 
-        self._init_search()
+        checkpoint_enabled = self.settings.autotune_checkpoint_dir is not None
+        if not (checkpoint_enabled and self._try_load_checkpoint()):
+            self._init_search()
         try:
             best = self._autotune()
+            if checkpoint_enabled:
+                self._cleanup_checkpoint()
         finally:
             self._finalize_autotune_metrics()
             end = time.perf_counter()
@@ -1119,6 +1233,7 @@ def _init_search(self) -> None:
         """
         Initialize the search state for a fresh autotuning run.
 
+        This method is called when starting autotuning without a checkpoint.
         Subclasses should override this to set up initial population and state.
        After this method, _current_generation should be set to the generation
        that _autotune() should start its loop from.
@@ -1135,6 +1250,68 @@ def _autotune(self) -> Config:
         """
         raise NotImplementedError
 
+    def save_checkpoint(self) -> Path | None:
+        """
+        Save current autotuner state to checkpoint file.
+
+        Only saves when autotune_checkpoint_dir is set (opt-in).
+        Overwrites the same file each generation (keyed by stable hash).
+        Uses pickle to serialize the entire autotuner object (minus unpicklable
+        fields excluded by __getstate__).
+
+        Returns:
+            Path to saved checkpoint file, or None if not saved
+        """
+        from ..runtime.kernel import BoundKernel
+
+        # External kernels don't support caching/checkpointing
+        if not isinstance(self.kernel, BoundKernel):
+            return None
+
+        if not self.kernel.is_cacheable():
+            return None
+
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return None  # Opt-in: no dir set, no saving
+
+        stable_hash = self._get_stable_hash()
+        checkpoint_dir = Path(checkpoint_dir_str)
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        checkpoint_path = checkpoint_dir / f"{stable_hash}.pt"
+
+        # Atomic write using temp file + rename
+        tmp = checkpoint_dir / f".tmp.{stable_hash}.{os.getpid()}"
+        with open(tmp, "wb") as f:
+            pickle.dump(self, f)
+        os.replace(tmp, checkpoint_path)
+
+        self.log(f"Checkpoint saved: {checkpoint_path}")
+        return checkpoint_path
+
+    def _cleanup_checkpoint(self) -> None:
+        """Delete checkpoint file on successful autotune completion.
+
+        Checkpoints are ephemeral in-progress state. Once autotuning
+        completes successfully, the result is cached normally and the
+        checkpoint is no longer needed.
+        """
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return
+
+        stable_hash = self._get_stable_hash()
+        checkpoint_file = Path(checkpoint_dir_str) / f"{stable_hash}.pt"
+        if checkpoint_file.exists():
+            checkpoint_file.unlink()
+            self.log(f"Checkpoint cleaned up: {checkpoint_file}")
+
+        # Clean up crash-recovery artifacts
+        for suffix in (".pending_config", ".crashed_configs"):
+            artifact = Path(checkpoint_dir_str) / f"{stable_hash}{suffix}"
+            if artifact.exists():
+                artifact.unlink()
+
     def set_generation(self, generation: int) -> None:
         self._autotune_metrics.num_generations = generation
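`save_checkpoint` uses the classic write-to-temp-then-rename idiom: a reader never observes a half-written checkpoint because `os.replace` swaps the complete file in atomically (it is atomic on POSIX, and on Windows when source and destination share a filesystem). A self-contained sketch of the idiom, with illustrative names:

```python
import os
import pickle
from pathlib import Path


def atomic_pickle(obj, directory: Path, name: str) -> Path:
    """Write obj to directory/name such that readers never see a partial file."""
    directory.mkdir(parents=True, exist_ok=True)
    final = directory / name
    # Keep the temp file in the same directory so os.replace is a
    # same-filesystem rename (atomic) rather than a cross-device copy.
    tmp = directory / f".tmp.{name}.{os.getpid()}"
    with open(tmp, "wb") as f:
        pickle.dump(obj, f)
    os.replace(tmp, final)  # atomically swap the complete file into place
    return final
```

Embedding the PID in the temp name, as the diff does, keeps concurrent writers from clobbering each other's in-flight temp files.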

@@ -1189,6 +1366,15 @@ class PopulationMember:
     def perf(self) -> float:
         return self.perfs[-1]
 
+    def __getstate__(self) -> dict[str, Any]:
+        state = self.__dict__.copy()
+        state["fn"] = None  # compiled functions are not picklable
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        self.__dict__.update(state)
+        self.fn = _unset_fn
+
 
 def performance(member: PopulationMember) -> float:
     """
@@ -1587,6 +1773,8 @@ def set_generation(self, generation: int) -> None:
             return
         self._current_generation = generation
         super().set_generation(generation)
+        if generation > 0 and self.settings.autotune_checkpoint_dir is not None:
+            self.save_checkpoint()
 
     def statistics(self) -> str:
         """
@@ -1597,6 +1785,30 @@ def statistics(self) -> str:
         """
         return population_statistics(self.population)
 
+    def _try_load_checkpoint(self) -> bool:
+        if not super()._try_load_checkpoint():
+            return False
+        # Recompile kernel functions for population members after checkpoint load
+        recompile_failures: list[tuple[PopulationMember, str]] = []
+        for member in self.population:
+            if member.fn is _unset_fn and member.status == "ok":
+                try:
+                    member.fn = self.kernel.compile_config(
+                        member.config, allow_print=False
+                    )
+                except Exception as e:
+                    member.fn = _unset_fn
+                    member.status = "error"
+                    member.perfs.append(inf)  # Ensure member won't be selected as best
+                    recompile_failures.append((member, str(e)))
+
+        if recompile_failures:
+            self.log(
+                f"Warning: {len(recompile_failures)} config(s) failed to recompile "
+                f"and will be skipped. First failure: {recompile_failures[0][1]}"
+            )
+        return True
+
     def run_finishing_phase(
         self, best: PopulationMember, rounds: int
     ) -> PopulationMember:
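The recompile-on-resume loop above quarantines members whose configs no longer compile instead of aborting the whole resume: failed members are marked as errors and given an infinite perf so they can never win. A generic stdlib sketch of that repair pass (dict-based members and `compile_fn` are illustrative stand-ins for Helion's objects):

```python
import math

UNSET = object()  # sentinel: "function not yet compiled"


def repair_members(members, compile_fn):
    """Rebuild each member's callable after resume; quarantine failures."""
    failures = []
    for m in members:
        if m["fn"] is UNSET and m["status"] == "ok":
            try:
                m["fn"] = compile_fn(m["config"])
            except Exception as e:
                m["status"] = "error"
                m["perf"] = math.inf  # guarantees it is never picked as best
                failures.append((m["config"], str(e)))
    return failures
```

Collecting failures and logging them once, as the diff does, keeps a resume with a few stale configs usable rather than fatal.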

helion/exc.py

Lines changed: 6 additions & 0 deletions
@@ -58,6 +58,12 @@ class AutotuneError(BaseError):
     message = "{0}"
 
 
+class CheckpointError(AutotuneError):
+    """Exception raised when checkpoint loading/saving fails."""
+
+    message = "{0}"
+
+
 class BackendImplementationMissing(BaseError):
     message = "Backend '{backend}' is missing required implementation: {detail}"

helion/runtime/settings.py

Lines changed: 12 additions & 0 deletions
@@ -492,6 +492,11 @@ class _Settings:
     autotune_baseline_rtol: float | None = None
     autotune_baseline_accuracy_check_fn: Callable[[object, object], None] | None = None
     autotune_benchmark_fn: Callable[..., list[float]] | None = None
+    autotune_checkpoint_dir: str | None = dataclasses.field(
+        default_factory=functools.partial(
+            os.environ.get, "HELION_AUTOTUNE_CHECKPOINT_DIR"
+        )
+    )
     autotune_best_available_max_configs: int = dataclasses.field(
         default_factory=functools.partial(
             _env_get_int, "HELION_BEST_AVAILABLE_MAX_CONFIGS", 20
@@ -640,6 +645,13 @@ class Settings(_Settings):
            "(fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None) -> list[float]. "
            "If None (default), uses the built-in benchmark function."
        ),
+       "autotune_checkpoint_dir": (
+           "Directory path for saving and resuming autotuning checkpoints. "
+           "When set, the autotuner saves in-progress state to this directory using the "
+           "kernel's stable hash as the filename, and auto-discovers matching checkpoints "
+           "on subsequent runs. The checkpoint file is deleted on successful completion. "
+           "When unset (default), no checkpoints are saved or loaded."
+       ),
        "autotune_best_available_max_configs": (
            "Maximum number of cached configs to use for FROM_BEST_AVAILABLE initial population strategy. "
            "Set HELION_BEST_AVAILABLE_MAX_CONFIGS=N to override. Default is 20."

0 commit comments