Commit e556287

[Autotuner] Auto-checkpoint feature and ability to resume from checkpoint

Fixes #1330. Internal customers hit frequent IMA errors, and spawn mode adds enough overhead that autotuning runs far longer than it should. This PR stack adds an auto-recovery feature: the autotuner checkpoints its state regularly (useful on its own for the server-crash scenarios mentioned in #1330), and a follow-up PR automatically starts a new autotune process from the saved checkpoint when an IMA error occurs.

stack-info: PR: #1920, branch: yf225/stack/90

1 parent ff57fe4 commit e556287

File tree

10 files changed: +1279 −211 lines changed


docs/api/settings.md

Lines changed: 9 additions & 0 deletions

```diff
@@ -209,6 +209,14 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
     Each preset also sets a default initial population strategy (see :doc:`../deployment_autotuning` for details).
     Users can still override individual ``autotune_*`` settings; explicit values win over the preset. Controlled by ``HELION_AUTOTUNE_EFFORT``.
 
+.. autoattribute:: Settings.autotune_checkpoint_dir
+
+    Directory path for saving and resuming autotuning checkpoints. When set, the autotuner
+    saves in-progress state to ``{dir}/{stable_hash}.pt`` and auto-discovers matching
+    checkpoints on subsequent runs. The checkpoint file is deleted on successful completion.
+    When unset (default), no checkpoints are saved or loaded (opt-in).
+    Controlled by ``HELION_AUTOTUNE_CHECKPOINT_DIR``.
+
 .. autoattribute:: Settings.autotune_best_available_max_configs
 
     Maximum number of cached configs to use when seeding the initial population with the ``from_best_available`` strategy.
@@ -323,6 +331,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"LFBOTreeSearch"`` (default),
 | ``HELION_AUTOTUNE_PROGRESS_BAR`` | ``autotune_progress_bar`` | Enable or disable the progress bar UI during autotuning. |
 | ``HELION_AUTOTUNE_IGNORE_ERRORS`` | ``autotune_ignore_errors`` | Continue autotuning even when recoverable runtime errors occur. |
 | ``HELION_AUTOTUNE_CONFIG_OVERRIDES`` | ``autotune_config_overrides`` | Supply JSON forcing particular autotuner config key/value pairs. |
+| ``HELION_AUTOTUNE_CHECKPOINT_DIR`` | ``autotune_checkpoint_dir`` | Directory path for saving/resuming autotuning checkpoints (opt-in). |
 | ``TRITON_STORE_BINARY_ONLY`` | Triton (autotuning) | Set to ``1`` during autotuning to skip Triton intermediate IRs, reducing cache size ~40%. Set to ``0`` to retain IRs for debugging. |
 | ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
 | ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, skip both reading and writing the autotuning cache entirely. |
```

docs/deployment_autotuning.md

Lines changed: 23 additions & 0 deletions

````diff
@@ -183,6 +183,29 @@ Related settings for `from_best_available` (see {doc}`api/settings`):
 | `autotune_best_available_max_configs` | `HELION_BEST_AVAILABLE_MAX_CONFIGS` | 20 | Maximum cached configs to seed |
 | `autotune_best_available_max_cache_scan` | `HELION_BEST_AVAILABLE_MAX_CACHE_SCAN` | 500 | Maximum cache files to scan |
 
+### Checkpointing Long-Running Autotuning
+
+For very long autotuning sessions, you can save and resume state using
+checkpoints. This is useful when tuning might be interrupted (e.g., preemptible
+instances) or when you want to continue tuning from a previous unfinished run.
+
+Set the `HELION_AUTOTUNE_CHECKPOINT_DIR` environment variable to a directory
+path. The autotuner will periodically save checkpoints there, keyed by the
+kernel's stable hash. If interrupted, re-run with the same directory to resume
+automatically. On successful completion, the checkpoint file is cleaned up.
+
+```bash
+# Enable checkpointing to a directory:
+HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
+
+# If interrupted, just re-run with the same directory to resume:
+HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
+```
+
+Without `HELION_AUTOTUNE_CHECKPOINT_DIR`, no checkpoints are saved (opt-in).
+Multiple kernels can safely use the same directory — each kernel writes to a
+file named by its unique stable hash.
+
 ## Deploy a Single Config
 
 If one configuration wins for every production call, bake it into the decorator:
````
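The same opt-in can be made from Python in a launcher script, before any autotuned kernel runs. A minimal sketch using only the standard library (the directory path is an arbitrary example, not a required location):

```python
import os

# Opt in to checkpointing before any Helion autotuning starts; the
# autotuner creates the directory on demand. The path is an example.
os.environ.setdefault("HELION_AUTOTUNE_CHECKPOINT_DIR", "/tmp/helion_checkpoints")

print(os.environ["HELION_AUTOTUNE_CHECKPOINT_DIR"])
```

`setdefault` keeps any value already exported in the shell, so the script stays overridable from the environment.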

helion/autotuner/base_search.py

Lines changed: 202 additions & 0 deletions

```diff
@@ -12,6 +12,8 @@
 import math
 from math import inf
 import os
+from pathlib import Path
+import pickle
 import pprint
 import random
 import re
```
```diff
@@ -343,6 +345,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
         self.args: Sequence[object] = args
         self.log = AutotuningLogger(self.settings)
         self.best_perf_so_far = inf
+        self._current_generation = 0
         self._prepared = False
         self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
         self._precompile_args_path: str | None = None
```
```diff
@@ -406,6 +409,91 @@ def cleanup(self) -> None:
         self._precompile_args_path = None
         self._precompile_result_counter = count()
 
+    # Fields excluded from pickle checkpoints: unpicklable infrastructure,
+    # fields recomputed by _prepare(), and fields loaded separately.
+    _CHECKPOINT_EXCLUDE = frozenset(
+        {
+            # Unpicklable infrastructure
+            "kernel",
+            "args",
+            "log",
+            "settings",
+            "config_spec",
+            "_precompile_tmpdir",
+            "_precompile_args_path",
+            "_precompile_result_counter",
+            # Recomputed by _prepare() before checkpoint load
+            "_baseline_output",
+            "_baseline_post_args",
+            "_mutated_arg_indices",
+            "_effective_atol",
+            "_effective_rtol",
+            "_jobs",
+            "_autotune_metrics",
+            "_prepared",
+            "_skip_cache",
+            # Loaded separately via _load_crashed_configs()
+            "_crashed_config_strs",
+        }
+    )
+
+    def __getstate__(self) -> dict[str, Any]:
+        return {
+            k: v for k, v in self.__dict__.items() if k not in self._CHECKPOINT_EXCLUDE
+        }
+
+    _stable_hash: str | None = None
+
+    def _get_stable_hash(self) -> str:
+        """Get the full stable hash for this kernel's cache key (cached)."""
+        if self._stable_hash is None:
+            from .local_cache import LocalAutotuneCache
+
+            self._stable_hash = LocalAutotuneCache(self)._generate_key().stable_hash()
+        return self._stable_hash
+
+    def _try_load_checkpoint(self) -> bool:
+        """Attempt to load checkpoint from checkpoint dir. Returns True if successful."""
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return False
+
+        checkpoint_dir = Path(checkpoint_dir_str)
+        stable_hash = self._get_stable_hash()
+        checkpoint_file = checkpoint_dir / f"{stable_hash}.pt"
+
+        if not checkpoint_file.exists():
+            return False  # No matching checkpoint; start fresh
+
+        # Matching file exists, attempt to load
+        self.log(f"Resuming from checkpoint: {checkpoint_file}")
+        try:
+            with open(checkpoint_file, "rb") as f:
+                loaded = pickle.load(f)
+        except Exception as e:
+            raise exc.CheckpointError(
+                f"Failed to load checkpoint file '{checkpoint_file}': {e}\n"
+                f"The file may be corrupted. Delete it to start fresh."
+            ) from e
+
+        # Validate stable hash matches (guards against renamed/copied files)
+        loaded_hash = getattr(loaded, "_stable_hash", None)
+        if loaded_hash is not None and loaded_hash != self._get_stable_hash():
+            raise exc.CheckpointError(
+                "Checkpoint is incompatible: kernel, hardware, or input shapes "
+                "may have changed."
+            )
+
+        # Copy loaded search state into self (self already has kernel, args,
+        # log, etc. from __init__ and _prepare())
+        self.__dict__.update(loaded.__dict__)
+
+        self._recompile_after_checkpoint()
+        self.log(f"Resumed at generation {self._current_generation}")
+        return True
+
+    def _recompile_after_checkpoint(self) -> None:
+        """Recompile after loading a checkpoint. Override in subclasses."""
+
     def _compute_baseline(
         self,
     ) -> tuple[object, Sequence[int], Sequence[object] | None]:
```
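The `__getstate__` exclusion pattern above can be illustrated standalone. In this toy sketch (the `Search` class and its `lock` field are invented stand-ins, not the real autotuner), excluded fields never reach `pickle`, so the checkpoint stays serializable even though the live object holds unpicklable infrastructure:

```python
import pickle
import threading

class Search:
    """Toy stand-in for the autotuner: mixes picklable search state with
    unpicklable infrastructure (a lock, standing in for kernel/log/etc.)."""

    _CHECKPOINT_EXCLUDE = frozenset({"lock"})

    def __init__(self) -> None:
        self.lock = threading.Lock()  # unpicklable infrastructure
        self.generation = 7           # picklable search state
        self.best_perf = 0.0123

    def __getstate__(self) -> dict:
        # Drop excluded fields so pickle never sees them.
        return {k: v for k, v in self.__dict__.items()
                if k not in self._CHECKPOINT_EXCLUDE}

# Round-trip: the loaded object keeps search state but has no lock attribute.
blob = pickle.dumps(Search())
loaded = pickle.loads(blob)
print(loaded.generation, hasattr(loaded, "lock"))  # 7 False
```

On load, the real code repopulates the excluded fields from a freshly constructed object (`self.__dict__.update(loaded.__dict__)` copies only the saved state over it).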
```diff
@@ -1089,8 +1177,12 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
             torch.save(self.args, args_path)
             self._precompile_args_path = args_path
             exit_stack.callback(self.cleanup)
+
+            if not self._try_load_checkpoint():
+                self._init_search()
             try:
                 best = self._autotune()
+                self._cleanup_checkpoint()
             finally:
                 self._finalize_autotune_metrics()
         end = time.perf_counter()
```
```diff
@@ -1112,6 +1204,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
             print(triton_code, file=sys.stderr)
         return best
 
+    def _init_search(self) -> None:
+        """
+        Initialize the search state for a fresh autotuning run.
+
+        This method is called when starting autotuning without a checkpoint.
+        Subclasses should override this to set up initial population and state.
+        After this method, _current_generation should be set to the generation
+        that _autotune() should start its loop from.
+        """
+
     def _autotune(self) -> Config:
         """
         Abstract method to perform the actual autotuning.
```
```diff
@@ -1123,6 +1225,68 @@ def _autotune(self) -> Config:
         """
         raise NotImplementedError
 
+    def save_checkpoint(self) -> Path | None:
+        """
+        Save current autotuner state to checkpoint file.
+
+        Only saves when autotune_checkpoint_dir is set (opt-in).
+        Overwrites the same file each generation (keyed by stable hash).
+        Uses pickle to serialize the entire autotuner object (minus unpicklable
+        fields excluded by __getstate__).
+
+        Returns:
+            Path to saved checkpoint file, or None if not saved
+        """
+        from ..runtime.kernel import BoundKernel
+
+        # External kernels don't support caching/checkpointing
+        if not isinstance(self.kernel, BoundKernel):
+            return None
+
+        if not self.kernel.is_cacheable():
+            return None
+
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return None  # Opt-in: no dir set, no saving
+
+        stable_hash = self._get_stable_hash()
+        checkpoint_dir = Path(checkpoint_dir_str)
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        checkpoint_path = checkpoint_dir / f"{stable_hash}.pt"
+
+        # Atomic write using temp file + rename
+        tmp = checkpoint_dir / f".tmp.{stable_hash}.{os.getpid()}"
+        with open(tmp, "wb") as f:
+            pickle.dump(self, f)
+        os.replace(tmp, checkpoint_path)
+
+        self.log(f"Checkpoint saved: {checkpoint_path}")
+        return checkpoint_path
+
+    def _cleanup_checkpoint(self) -> None:
+        """Delete checkpoint file on successful autotune completion.
+
+        Checkpoints are ephemeral in-progress state. Once autotuning
+        completes successfully, the result is cached normally and the
+        checkpoint is no longer needed.
+        """
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return
+
+        stable_hash = self._get_stable_hash()
+        checkpoint_file = Path(checkpoint_dir_str) / f"{stable_hash}.pt"
+        if checkpoint_file.exists():
+            checkpoint_file.unlink()
+            self.log(f"Checkpoint cleaned up: {checkpoint_file}")
+
+        # Clean up crash-recovery artifacts
+        for suffix in (".pending_config", ".crashed_configs"):
+            artifact = Path(checkpoint_dir_str) / f"{stable_hash}{suffix}"
+            if artifact.exists():
+                artifact.unlink()
+
     def set_generation(self, generation: int) -> None:
         self._autotune_metrics.num_generations = generation
```
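The temp-file-plus-`os.replace` write in `save_checkpoint` is the standard atomic-save idiom: `os.replace` atomically swaps the file into place, so a reader sees either the old checkpoint or the new one, never a torn half-write, even if the process dies mid-dump. A standalone sketch (the helper name `save_atomically` is invented for illustration):

```python
import os
import pickle
import tempfile
from pathlib import Path

def save_atomically(path: Path, state: object) -> None:
    """Write the pickle to a temp file in the same directory, then rename.
    The temp file must live on the same filesystem for the rename to be
    atomic; using the target's own directory guarantees that."""
    tmp = path.with_name(f".tmp.{path.name}.{os.getpid()}")
    with open(tmp, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp, path)  # atomic swap into place

with tempfile.TemporaryDirectory() as d:
    ckpt = Path(d) / "abc123.pt"
    save_atomically(ckpt, {"generation": 3})
    with open(ckpt, "rb") as f:
        loaded = pickle.load(f)
print(loaded)  # {'generation': 3}
```

Including the PID in the temp name, as the diff does, keeps concurrent processes writing to the same directory from clobbering each other's in-progress temp files.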

```diff
@@ -1177,6 +1341,15 @@ class PopulationMember:
     def perf(self) -> float:
         return self.perfs[-1]
 
+    def __getstate__(self) -> dict[str, Any]:
+        state = self.__dict__.copy()
+        state["fn"] = None  # compiled functions are not picklable
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        self.__dict__.update(state)
+        self.fn = _unset_fn
+
 
 def performance(member: PopulationMember) -> float:
     """
```
```diff
@@ -1570,6 +1743,14 @@ def rebenchmark_population(
             members = self.population
         self.rebenchmark([p for p in members if self.should_rebenchmark(p)], desc=desc)
 
+    def set_generation(self, generation: int) -> None:
+        if generation == self._current_generation:
+            return
+        self._current_generation = generation
+        super().set_generation(generation)
+        if generation > 0:
+            self.save_checkpoint()
+
     def statistics(self) -> str:
         """
         Generate statistics for the current population.
```
```diff
@@ -1579,6 +1760,27 @@ def statistics(self) -> str:
         """
        return population_statistics(self.population)
 
+    def _recompile_after_checkpoint(self) -> None:
+        """Recompile kernel functions for population members after checkpoint load."""
+        recompile_failures: list[tuple[PopulationMember, str]] = []
+        for member in self.population:
+            if member.fn is _unset_fn and member.status == "ok":
+                try:
+                    member.fn = self.kernel.compile_config(
+                        member.config, allow_print=False
+                    )
+                except Exception as e:
+                    member.fn = _unset_fn
+                    member.status = "error"
+                    member.perfs.append(inf)  # Ensure member won't be selected as best
+                    recompile_failures.append((member, str(e)))
+
+        if recompile_failures:
+            self.log(
+                f"Warning: {len(recompile_failures)} config(s) failed to recompile "
+                f"and will be skipped. First failure: {recompile_failures[0][1]}"
+            )
+
     def run_finishing_phase(
         self, best: PopulationMember, rounds: int
     ) -> PopulationMember:
```
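`_recompile_after_checkpoint` degrades gracefully: a member whose config no longer compiles is demoted with an `inf` perf rather than aborting the whole resume. A minimal sketch of that failure-tolerance pattern, with plain dicts standing in for `PopulationMember` and a hypothetical `compile_config` callable:

```python
from math import inf

def recompile(members: list[dict], compile_config) -> list[tuple[dict, str]]:
    """Post-load recompile pass: collect failures instead of raising, and
    demote failed members so they can never win the search."""
    failures = []
    for m in members:
        try:
            m["fn"] = compile_config(m["config"])
        except Exception as e:
            m["status"] = "error"
            m["perfs"].append(inf)   # never selected as best
            failures.append((m, str(e)))
    return failures

members = [{"config": 1, "perfs": [], "status": "ok"},
           {"config": 2, "perfs": [], "status": "ok"}]

def compile_config(cfg):
    if cfg == 2:
        raise ValueError("bad config")   # simulate a stale/broken config
    return lambda: cfg

fails = recompile(members, compile_config)
print(len(fails), members[1]["status"], members[1]["perfs"])  # 1 error [inf]
```

Logging only the first failure message, as the diff does, keeps the resume log readable when many stale configs break at once.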

helion/autotuner/de_surrogate_hybrid.py

Lines changed: 12 additions & 6 deletions

```diff
@@ -135,12 +135,9 @@ def __init__(
         # Track all evaluations for surrogate training
         self.all_observations: list[tuple[FlatConfig, float]] = []
 
-    def _autotune(self) -> Config:
+    def _init_search(self) -> None:
         """
-        Run DE with surrogate-assisted selection.
-
-        Returns:
-            Best configuration found
+        Initialize DE with surrogate-assisted selection.
         """
         self.log("=" * 70)
         self.log("Differential Evolution with Surrogate-Assisted Selection")
@@ -174,8 +171,17 @@ def _autotune(self) -> Config:
         self.best_perf_history = [self.best.perf]
         self.generations_without_improvement = 0
 
+        self.set_generation(2)
+
+    def _autotune(self) -> Config:
+        """
+        Run DE with surrogate-assisted selection.
+
+        Returns:
+            Best configuration found
+        """
         # Evolution loop
-        for gen in range(2, self.max_generations + 1):
+        for gen in range(self._current_generation, self.max_generations + 1):
             self.set_generation(gen)
             self._evolve_generation(gen)
 
```
helion/autotuner/differential_evolution.py

Lines changed: 9 additions & 2 deletions

```diff
@@ -236,7 +236,7 @@ def check_early_stopping(self) -> bool:
             self.generations_without_improvement = 0
             return False
 
-    def _autotune(self) -> Config:
+    def _init_search(self) -> None:
         early_stopping_enabled = (
             self.min_improvement_delta is not None and self.patience is not None
         )
@@ -265,7 +265,14 @@ def _autotune(self) -> Config:
         self.best_perf_history = [self.best.perf]
         self.generations_without_improvement = 0
 
-        for i in range(2, self.max_generations):
+        self.set_generation(2)
+
+    def _autotune(self) -> Config:
+        early_stopping_enabled = (
+            self.min_improvement_delta is not None and self.patience is not None
+        )
+
+        for i in range(self._current_generation, self.max_generations):
             self.set_generation(i)
             self.log(f"Generation {i} starting")
             replaced = self.evolve_population()
```
