
Commit 16d3fb7

[Autotuner] Auto-checkpoint feature and ability to resume from checkpoint
Fixes #1330. Internal customers have had a lot of pain with IMA errors, and they also felt that spawn mode added too much overhead, making autotuning take far longer. This PR stack adds an auto-recovery feature: the autotuner checkpoints its state regularly (useful on its own for the server-crash scenarios mentioned in #1330) and, on an IMA error, automatically starts a new autotune process from the previously saved checkpoint (next PR).

stack-info: PR: #1920, branch: yf225/stack/90
1 parent ff57fe4 commit 16d3fb7

File tree

12 files changed: +1,239 additions, −212 deletions

docs/api/settings.md (9 additions, 0 deletions)

```diff
@@ -209,6 +209,14 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
     Each preset also sets a default initial population strategy (see :doc:`../deployment_autotuning` for details).
     Users can still override individual ``autotune_*`` settings; explicit values win over the preset. Controlled by ``HELION_AUTOTUNE_EFFORT``.
 
+.. autoattribute:: Settings.autotune_checkpoint_dir
+
+    Directory path for saving and resuming autotuning checkpoints. When set, the autotuner
+    saves in-progress state to ``{dir}/{stable_hash}.pt`` and auto-discovers matching
+    checkpoints on subsequent runs. The checkpoint file is deleted on successful completion.
+    When unset (default), no checkpoints are saved or loaded (opt-in).
+    Controlled by ``HELION_AUTOTUNE_CHECKPOINT_DIR``.
+
 .. autoattribute:: Settings.autotune_best_available_max_configs
 
     Maximum number of cached configs to use when seeding the initial population with the ``from_best_available`` strategy.
@@ -323,6 +331,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"LFBOTreeSearch"`` (default),
 | ``HELION_AUTOTUNE_PROGRESS_BAR`` | ``autotune_progress_bar`` | Enable or disable the progress bar UI during autotuning. |
 | ``HELION_AUTOTUNE_IGNORE_ERRORS`` | ``autotune_ignore_errors`` | Continue autotuning even when recoverable runtime errors occur. |
 | ``HELION_AUTOTUNE_CONFIG_OVERRIDES`` | ``autotune_config_overrides`` | Supply JSON forcing particular autotuner config key/value pairs. |
+| ``HELION_AUTOTUNE_CHECKPOINT_DIR`` | ``autotune_checkpoint_dir`` | Directory path for saving/resuming autotuning checkpoints (opt-in). |
 | ``TRITON_STORE_BINARY_ONLY`` | Triton (autotuning) | Set to ``1`` during autotuning to skip Triton intermediate IRs, reducing cache size ~40%. Set to ``0`` to retain IRs for debugging. |
 | ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
 | ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, skip both reading and writing the autotuning cache entirely. |
```
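The documented lifecycle (save to `{dir}/{stable_hash}.pt`, auto-discover on the next run, delete on success) can be illustrated with a generic sketch. This is not Helion's implementation; the hash value is a stand-in for the kernel's stable hash:

```python
import os
import pickle
import tempfile

# Generic sketch (not Helion's code) of the checkpoint lifecycle:
# one file per kernel, keyed by a stable hash.
checkpoint_dir = tempfile.mkdtemp()
stable_hash = "deadbeef"  # stand-in for the kernel's stable hash
path = os.path.join(checkpoint_dir, f"{stable_hash}.pt")

# First run: no checkpoint exists, so tuning starts fresh; progress is
# saved periodically, then the run is interrupted.
if not os.path.exists(path):
    state = {"generation": 3}
    with open(path, "wb") as f:
        pickle.dump(state, f)

# Second run: a matching checkpoint is auto-discovered and resumed.
with open(path, "rb") as f:
    resumed = pickle.load(f)

# On successful completion, the checkpoint file is deleted.
os.remove(path)
print(resumed["generation"])
```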

docs/deployment_autotuning.md (23 additions, 0 deletions)

```diff
@@ -183,6 +183,29 @@ Related settings for `from_best_available` (see {doc}`api/settings`):
 | `autotune_best_available_max_configs` | `HELION_BEST_AVAILABLE_MAX_CONFIGS` | 20 | Maximum cached configs to seed |
 | `autotune_best_available_max_cache_scan` | `HELION_BEST_AVAILABLE_MAX_CACHE_SCAN` | 500 | Maximum cache files to scan |
 
+### Checkpointing Long-Running Autotuning
+
+For very long autotuning sessions, you can save and resume state using
+checkpoints. This is useful when tuning might be interrupted (e.g., preemptible
+instances) or when you want to continue tuning from a previous unfinished run.
+
+Set the `HELION_AUTOTUNE_CHECKPOINT_DIR` environment variable to a directory
+path. The autotuner will periodically save checkpoints there, keyed by the
+kernel's stable hash. If interrupted, re-run with the same directory to resume
+automatically. On successful completion, the checkpoint file is cleaned up.
+
+```bash
+# Enable checkpointing to a directory:
+HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
+
+# If interrupted, just re-run with the same directory to resume:
+HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
+```
+
+Without `HELION_AUTOTUNE_CHECKPOINT_DIR`, no checkpoints are saved (opt-in).
+Multiple kernels can safely use the same directory — each kernel writes to a
+file named by its unique stable hash.
+
 ## Deploy a Single Config
 
 If one configuration wins for every production call, bake it into the decorator:
```

helion/_testing.py (1 addition, 0 deletions)

```diff
@@ -57,6 +57,7 @@
 from .runtime.kernel import Kernel
 
 
+
 def _strip_launcher_args(value: str) -> str:
     strip_pairs = []
     if supports_amd_cdna_tunables():
```

helion/autotuner/base_search.py (203 additions, 0 deletions)

```diff
@@ -12,6 +12,8 @@
 import math
 from math import inf
 import os
+from pathlib import Path
+import pickle
 import pprint
 import random
 import re
@@ -343,6 +345,8 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
         self.args: Sequence[object] = args
         self.log = AutotuningLogger(self.settings)
         self.best_perf_so_far = inf
+        self._current_generation = 0
+        self.counters: collections.Counter[str] = collections.Counter()
         self._prepared = False
         self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
         self._precompile_args_path: str | None = None
@@ -406,6 +410,90 @@ def cleanup(self) -> None:
         self._precompile_args_path = None
         self._precompile_result_counter = count()
 
+    # Fields excluded from pickle checkpoints: unpicklable infrastructure,
+    # fields recomputed by _prepare(), and fields loaded separately.
+    _CHECKPOINT_EXCLUDE = frozenset(
+        {
+            # Unpicklable infrastructure
+            "kernel",
+            "args",
+            "log",
+            "settings",
+            "config_spec",
+            "_precompile_tmpdir",
+            "_precompile_args_path",
+            "_precompile_result_counter",
+            # Recomputed by _prepare() before checkpoint load
+            "_baseline_output",
+            "_baseline_post_args",
+            "_mutated_arg_indices",
+            "_effective_atol",
+            "_effective_rtol",
+            "_jobs",
+            "_autotune_metrics",
+            "_prepared",
+            "_skip_cache",
+            # Loaded separately via _load_crashed_configs()
+            "_crashed_config_strs",
+        }
+    )
+
+    def __getstate__(self) -> dict[str, Any]:
+        return {
+            k: v for k, v in self.__dict__.items() if k not in self._CHECKPOINT_EXCLUDE
+        }
+
+    _stable_hash: str | None = None
+
+    def _get_stable_hash(self) -> str:
+        """Get the full stable hash for this kernel's cache key (cached)."""
+        if self._stable_hash is None:
+            from .local_cache import LocalAutotuneCache
+
+            self._stable_hash = LocalAutotuneCache(self)._generate_key().stable_hash()
+        return self._stable_hash
+
+    def _try_load_checkpoint(self) -> bool:
+        """Attempt to load a checkpoint from the checkpoint dir. Returns True if successful."""
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return False
+
+        checkpoint_dir = Path(checkpoint_dir_str)
+        stable_hash = self._get_stable_hash()
+        checkpoint_file = checkpoint_dir / f"{stable_hash}.pt"
+
+        if not checkpoint_file.exists():
+            return False  # No matching checkpoint; start fresh
+
+        # Matching file exists, attempt to load
+        self.log(f"Resuming from checkpoint: {checkpoint_file}")
+        try:
+            with open(checkpoint_file, "rb") as f:
+                loaded = pickle.load(f)
+        except Exception as e:
+            raise exc.CheckpointError(
+                f"Failed to load checkpoint file '{checkpoint_file}': {e}\n"
+                f"The file may be corrupted. Delete it to start fresh."
+            ) from e
+
+        # Validate stable hash matches (guards against renamed/copied files)
+        loaded_hash = getattr(loaded, "_stable_hash", None)
+        if loaded_hash is not None and loaded_hash != self._get_stable_hash():
+            raise exc.CheckpointError(
+                "Checkpoint is incompatible: kernel, hardware, or input shapes "
+                "may have changed."
+            )
+
+        # Copy loaded search state into self (self already has kernel, args,
+        # log, etc. from __init__ and _prepare())
+        self.__dict__.update(loaded.__dict__)
+        self._recompile_after_checkpoint()
+        self.log(f"Resumed at generation {self._current_generation}")
+        return True
+
+    def _recompile_after_checkpoint(self) -> None:
+        """Recompile after loading a checkpoint. Override in subclasses."""
+
     def _compute_baseline(
         self,
     ) -> tuple[object, Sequence[int], Sequence[object] | None]:
```
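The `_CHECKPOINT_EXCLUDE` / `__getstate__` combination above is a standard pickle customization pattern: fields that cannot or should not be serialized are dropped before pickling, and the loader copies the surviving state into a freshly constructed object. A self-contained sketch of the same idea (class and field names are illustrative, not Helion's):

```python
import pickle

class Search:
    # Fields that must not be pickled (illustrative names).
    _EXCLUDE = frozenset({"log", "kernel"})

    def __init__(self):
        self.best_perf = 1.5          # picklable search state
        self.log = lambda msg: None   # unpicklable infrastructure
        self.kernel = object()        # infrastructure rebuilt on load

    def __getstate__(self):
        # Drop excluded fields; everything else round-trips via pickle.
        return {k: v for k, v in self.__dict__.items() if k not in self._EXCLUDE}

s = pickle.loads(pickle.dumps(Search()))
# Excluded fields are simply absent after unpickling; the loading code
# is expected to supply them from a freshly initialized object.
print(hasattr(s, "log"), s.best_perf)
```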
```diff
@@ -629,6 +717,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
             The performance of the configuration in ms.
         """
         self._autotune_metrics.num_configs_tested += 1
+        self.counters["benchmark"] += 1
         self.log.debug(lambda: f"Running benchmark for {config!r}")
         _captured_output: list[str] = [""]
         _capture_ctx = (
@@ -1089,8 +1178,12 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
         torch.save(self.args, args_path)
         self._precompile_args_path = args_path
         exit_stack.callback(self.cleanup)
+
+        if not self._try_load_checkpoint():
+            self._init_search()
         try:
             best = self._autotune()
+            self._cleanup_checkpoint()
         finally:
             self._finalize_autotune_metrics()
             end = time.perf_counter()
@@ -1112,6 +1205,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
             print(triton_code, file=sys.stderr)
         return best
 
+    def _init_search(self) -> None:
+        """
+        Initialize the search state for a fresh autotuning run.
+
+        This method is called when starting autotuning without a checkpoint.
+        Subclasses should override this to set up initial population and state.
+        After this method, _current_generation should be set to the generation
+        that _autotune() should start its loop from.
+        """
+
     def _autotune(self) -> Config:
         """
         Abstract method to perform the actual autotuning.
@@ -1123,6 +1226,68 @@ def _autotune(self) -> Config:
         """
         raise NotImplementedError
 
+    def save_checkpoint(self) -> Path | None:
+        """
+        Save current autotuner state to a checkpoint file.
+
+        Only saves when autotune_checkpoint_dir is set (opt-in).
+        Overwrites the same file each generation (keyed by stable hash).
+        Uses pickle to serialize the entire autotuner object (minus unpicklable
+        fields excluded by __getstate__).
+
+        Returns:
+            Path to the saved checkpoint file, or None if not saved
+        """
+        from ..runtime.kernel import BoundKernel
+
+        # External kernels don't support caching/checkpointing
+        if not isinstance(self.kernel, BoundKernel):
+            return None
+
+        if not self.kernel.is_cacheable():
+            return None
+
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return None  # Opt-in: no dir set, no saving
+
+        stable_hash = self._get_stable_hash()
+        checkpoint_dir = Path(checkpoint_dir_str)
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        checkpoint_path = checkpoint_dir / f"{stable_hash}.pt"
+
+        # Atomic write using temp file + rename
+        tmp = checkpoint_dir / f".tmp.{stable_hash}.{os.getpid()}"
+        with open(tmp, "wb") as f:
+            pickle.dump(self, f)
+        os.replace(tmp, checkpoint_path)
+
+        self.log(f"Checkpoint saved: {checkpoint_path}")
+        return checkpoint_path
+
+    def _cleanup_checkpoint(self) -> None:
+        """Delete the checkpoint file on successful autotune completion.
+
+        Checkpoints are ephemeral in-progress state. Once autotuning
+        completes successfully, the result is cached normally and the
+        checkpoint is no longer needed.
+        """
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return
+
+        stable_hash = self._get_stable_hash()
+        checkpoint_file = Path(checkpoint_dir_str) / f"{stable_hash}.pt"
+        if checkpoint_file.exists():
+            checkpoint_file.unlink()
+            self.log(f"Checkpoint cleaned up: {checkpoint_file}")
+
+        # Clean up crash-recovery artifacts
+        for suffix in (".pending_config", ".crashed_configs"):
+            artifact = Path(checkpoint_dir_str) / f"{stable_hash}{suffix}"
+            if artifact.exists():
+                artifact.unlink()
+
     def set_generation(self, generation: int) -> None:
         self._autotune_metrics.num_generations = generation
```

```diff
@@ -1177,6 +1342,15 @@ class PopulationMember:
     def perf(self) -> float:
         return self.perfs[-1]
 
+    def __getstate__(self) -> dict[str, Any]:
+        state = self.__dict__.copy()
+        state["fn"] = None  # compiled functions are not picklable
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        self.__dict__.update(state)
+        self.fn = _unset_fn
+
 
 def performance(member: PopulationMember) -> float:
     """
@@ -1570,6 +1744,14 @@ def rebenchmark_population(
             members = self.population
         self.rebenchmark([p for p in members if self.should_rebenchmark(p)], desc=desc)
 
+    def set_generation(self, generation: int) -> None:
+        if generation == self._current_generation:
+            return
+        self._current_generation = generation
+        super().set_generation(generation)
+        if generation > 0:
+            self.save_checkpoint()
+
     def statistics(self) -> str:
         """
         Generate statistics for the current population.
@@ -1579,6 +1761,27 @@ def statistics(self) -> str:
         """
         return population_statistics(self.population)
 
+    def _recompile_after_checkpoint(self) -> None:
+        """Recompile kernel functions for population members after checkpoint load."""
+        recompile_failures: list[tuple[PopulationMember, str]] = []
+        for member in self.population:
+            if member.fn is _unset_fn and member.status == "ok":
+                try:
+                    member.fn = self.kernel.compile_config(
+                        member.config, allow_print=False
+                    )
+                except Exception as e:
+                    member.fn = _unset_fn
+                    member.status = "error"
+                    member.perfs.append(inf)  # Ensure member won't be selected as best
+                    recompile_failures.append((member, str(e)))
+
+        if recompile_failures:
+            self.log(
+                f"Warning: {len(recompile_failures)} config(s) failed to recompile "
+                f"and will be skipped. First failure: {recompile_failures[0][1]}"
+            )
+
     def run_finishing_phase(
         self, best: PopulationMember, rounds: int
     ) -> PopulationMember:
```
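The `PopulationMember` hooks plus `_recompile_after_checkpoint` implement a common pattern: drop an unpicklable compiled artifact when serializing, replace it with a sentinel on load, and rebuild it afterward. A self-contained sketch of the pattern (names are illustrative, not Helion's):

```python
import pickle

_UNSET = object()  # sentinel; identity is stable within one process

class Member:
    def __init__(self, config):
        self.config = config
        self.fn = lambda: config * 2  # stand-in for an unpicklable compiled fn

    def __getstate__(self):
        state = self.__dict__.copy()
        state["fn"] = None  # drop the unpicklable compiled function
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.fn = _UNSET  # mark for recompilation after load

def recompile(members):
    # Rebuild the dropped functions from the (picklable) configs,
    # mirroring what _recompile_after_checkpoint does.
    for m in members:
        if m.fn is _UNSET:
            m.fn = lambda c=m.config: c * 2

members = pickle.loads(pickle.dumps([Member(3), Member(5)]))
recompile(members)
print([m.fn() for m in members])
```

Using a sentinel rather than `None` makes "needs recompilation" distinguishable from a legitimately absent function, which is why the real code checks `member.fn is _unset_fn`.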

helion/autotuner/de_surrogate_hybrid.py (12 additions, 6 deletions)

```diff
@@ -135,12 +135,9 @@ def __init__(
         # Track all evaluations for surrogate training
         self.all_observations: list[tuple[FlatConfig, float]] = []
 
-    def _autotune(self) -> Config:
+    def _init_search(self) -> None:
         """
-        Run DE with surrogate-assisted selection.
-
-        Returns:
-            Best configuration found
+        Initialize DE with surrogate-assisted selection.
         """
         self.log("=" * 70)
         self.log("Differential Evolution with Surrogate-Assisted Selection")
@@ -174,8 +171,17 @@ def _init_search(self) -> None:
         self.best_perf_history = [self.best.perf]
         self.generations_without_improvement = 0
 
+        self.set_generation(2)
+
+    def _autotune(self) -> Config:
+        """
+        Run DE with surrogate-assisted selection.
+
+        Returns:
+            Best configuration found
+        """
         # Evolution loop
-        for gen in range(2, self.max_generations + 1):
+        for gen in range(self._current_generation, self.max_generations + 1):
             self.set_generation(gen)
             self._evolve_generation(gen)
```
helion/autotuner/differential_evolution.py (9 additions, 2 deletions)

```diff
@@ -236,7 +236,7 @@ def check_early_stopping(self) -> bool:
             self.generations_without_improvement = 0
         return False
 
-    def _autotune(self) -> Config:
+    def _init_search(self) -> None:
         early_stopping_enabled = (
             self.min_improvement_delta is not None and self.patience is not None
         )
@@ -265,7 +265,14 @@ def _init_search(self) -> None:
         self.best_perf_history = [self.best.perf]
         self.generations_without_improvement = 0
 
-        for i in range(2, self.max_generations):
+        self.set_generation(2)
+
+    def _autotune(self) -> Config:
+        early_stopping_enabled = (
+            self.min_improvement_delta is not None and self.patience is not None
+        )
+
+        for i in range(self._current_generation, self.max_generations):
             self.set_generation(i)
             self.log(f"Generation {i} starting")
             replaced = self.evolve_population()
```
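Both searchers follow the same refactoring shape: `_init_search` builds the initial population and sets the generation counter, while `_autotune` starts its loop at `_current_generation`, so a resumed run (which skips `_init_search`) continues where it left off. A schematic sketch of that split, with illustrative names:

```python
class Search:
    def __init__(self, max_generations=5):
        self.max_generations = max_generations
        self._current_generation = 0
        self.history = []  # generations actually executed

    def _init_search(self):
        # Fresh run: build the initial population, then point the loop at gen 2
        # (generations 0/1 correspond to initialization in the real searchers).
        self._current_generation = 2

    def _autotune(self):
        # Resumed runs skip _init_search; the loop picks up from the
        # generation restored out of the checkpoint.
        for gen in range(self._current_generation, self.max_generations + 1):
            self._current_generation = gen
            self.history.append(gen)

fresh = Search()
fresh._init_search()
fresh._autotune()

resumed = Search()
resumed._current_generation = 4  # as if restored from a checkpoint
resumed._autotune()
print(fresh.history, resumed.history)
```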
