Commit 20726a8
[Autotuner] Auto-checkpoint feature and ability to resume from checkpoint
Fixes #1330. Internal customers have had significant pain with IMA errors, and they also find that spawn mode adds too much overhead, making autotuning take much longer. This PR stack adds an auto-recovery feature: it checkpoints regularly (which is useful on its own for the server-crash scenarios mentioned in #1330) and then automatically starts a new autotune process from the previously saved checkpoint when an IMA error occurs (next PR). stack-info: PR: #1920, branch: yf225/stack/90
1 parent ff57fe4 commit 20726a8
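In outline, the recovery flow this stack adds looks like the following sketch (hypothetical helper names, not Helion's API; the real logic lives in `BaseSearch` in `helion/autotuner/base_search.py`): checkpoint after each generation, resume from the file when it exists, and delete it once tuning succeeds.

```python
import os
import pickle

def tune(checkpoint_path: str, max_generations: int, evolve, init_state):
    # Resume from a prior run if a checkpoint exists, else start fresh.
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, "rb") as f:
            state = pickle.load(f)
    else:
        state = init_state()
    for gen in range(state["generation"], max_generations):
        evolve(state)  # one generation of search
        state["generation"] = gen + 1
        with open(checkpoint_path, "wb") as f:
            pickle.dump(state, f)  # checkpoint every generation
    os.remove(checkpoint_path)  # ephemeral: gone once tuning succeeds
    return state
```

A crashed run re-enters the loop at the last saved generation instead of restarting from zero, which is exactly what makes IMA-triggered restarts cheap.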

File tree

12 files changed (+1246, -212 lines)

docs/api/settings.md

Lines changed: 9 additions & 0 deletions

```diff
@@ -209,6 +209,14 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
     Each preset also sets a default initial population strategy (see :doc:`../deployment_autotuning` for details).
     Users can still override individual ``autotune_*`` settings; explicit values win over the preset. Controlled by ``HELION_AUTOTUNE_EFFORT``.
 
+.. autoattribute:: Settings.autotune_checkpoint_dir
+
+    Directory path for saving and resuming autotuning checkpoints. When set, the autotuner
+    saves in-progress state to ``{dir}/{stable_hash}.pt`` and auto-discovers matching
+    checkpoints on subsequent runs. The checkpoint file is deleted on successful completion.
+    When unset (default), no checkpoints are saved or loaded (opt-in).
+    Controlled by ``HELION_AUTOTUNE_CHECKPOINT_DIR``.
+
 .. autoattribute:: Settings.autotune_best_available_max_configs
 
     Maximum number of cached configs to use when seeding the initial population with the ``from_best_available`` strategy.
@@ -323,6 +331,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"LFBOTreeSearch"`` (default),
 | ``HELION_AUTOTUNE_PROGRESS_BAR`` | ``autotune_progress_bar`` | Enable or disable the progress bar UI during autotuning. |
 | ``HELION_AUTOTUNE_IGNORE_ERRORS`` | ``autotune_ignore_errors`` | Continue autotuning even when recoverable runtime errors occur. |
 | ``HELION_AUTOTUNE_CONFIG_OVERRIDES`` | ``autotune_config_overrides`` | Supply JSON forcing particular autotuner config key/value pairs. |
+| ``HELION_AUTOTUNE_CHECKPOINT_DIR`` | ``autotune_checkpoint_dir`` | Directory path for saving/resuming autotuning checkpoints (opt-in). |
 | ``TRITON_STORE_BINARY_ONLY`` | Triton (autotuning) | Set to ``1`` during autotuning to skip Triton intermediate IRs, reducing cache size ~40%. Set to ``0`` to retain IRs for debugging. |
 | ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
 | ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, skip both reading and writing the autotuning cache entirely. |
```

docs/deployment_autotuning.md

Lines changed: 23 additions & 0 deletions

````diff
@@ -183,6 +183,29 @@ Related settings for `from_best_available` (see {doc}`api/settings`):
 | `autotune_best_available_max_configs` | `HELION_BEST_AVAILABLE_MAX_CONFIGS` | 20 | Maximum cached configs to seed |
 | `autotune_best_available_max_cache_scan` | `HELION_BEST_AVAILABLE_MAX_CACHE_SCAN` | 500 | Maximum cache files to scan |
 
+### Checkpointing Long-Running Autotuning
+
+For very long autotuning sessions, you can save and resume state using
+checkpoints. This is useful when tuning might be interrupted (e.g., preemptible
+instances) or when you want to continue tuning from a previous unfinished run.
+
+Set the `HELION_AUTOTUNE_CHECKPOINT_DIR` environment variable to a directory
+path. The autotuner will periodically save checkpoints there, keyed by the
+kernel's stable hash. If interrupted, re-run with the same directory to resume
+automatically. On successful completion, the checkpoint file is cleaned up.
+
+```bash
+# Enable checkpointing to a directory:
+HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
+
+# If interrupted, just re-run with the same directory to resume:
+HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
+```
+
+Without `HELION_AUTOTUNE_CHECKPOINT_DIR`, no checkpoints are saved (opt-in).
+Multiple kernels can safely use the same directory — each kernel writes to a
+file named by its unique stable hash.
+
 ## Deploy a Single Config
 
 If one configuration wins for every production call, bake it into the decorator:
````
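The note above about sharing one directory across kernels works because filenames are keyed by a per-kernel stable hash. A toy illustration of the idea, using a SHA-256 digest of kernel source and input shapes as a stand-in for Helion's real stable hash:

```python
import hashlib

def checkpoint_filename(kernel_source: str, input_shapes: str) -> str:
    # Stand-in for Helion's stable hash: any digest that is deterministic
    # across runs but distinct per kernel/shape combination will do.
    digest = hashlib.sha256(f"{kernel_source}|{input_shapes}".encode()).hexdigest()
    return f"{digest}.pt"

a = checkpoint_filename("def add(x, y): ...", "(1024,)")
b = checkpoint_filename("def mul(x, y): ...", "(1024,)")
```

Distinct kernels produce distinct filenames, so concurrent tuning jobs never clobber each other's checkpoints in a shared directory.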

helion/_testing.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -57,6 +57,7 @@
 from .runtime.kernel import Kernel
 
 
+
 def _strip_launcher_args(value: str) -> str:
     strip_pairs = []
     if supports_amd_cdna_tunables():
```

helion/autotuner/base_search.py

Lines changed: 210 additions & 0 deletions

```diff
@@ -12,6 +12,8 @@
 import math
 from math import inf
 import os
+from pathlib import Path
+import pickle
 import pprint
 import random
 import re
@@ -343,6 +345,15 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
         self.args: Sequence[object] = args
         self.log = AutotuningLogger(self.settings)
         self.best_perf_so_far = inf
+        self._current_generation = 0
+        self.counters: collections.Counter[str] = collections.Counter()
+        self._autotune_metrics: AutotuneMetrics = AutotuneMetrics(
+            kernel_name="",
+            input_shapes="",
+            hardware="",
+            random_seed=0,
+            search_algorithm=type(self).__name__,
+        )
         self._prepared = False
         self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
         self._precompile_args_path: str | None = None
@@ -406,6 +417,90 @@ def cleanup(self) -> None:
         self._precompile_args_path = None
         self._precompile_result_counter = count()
 
+    # Fields excluded from pickle checkpoints: unpicklable infrastructure,
+    # fields recomputed by _prepare(), and fields loaded separately.
+    _CHECKPOINT_EXCLUDE = frozenset(
+        {
+            # Unpicklable infrastructure
+            "kernel",
+            "args",
+            "log",
+            "settings",
+            "config_spec",
+            "_precompile_tmpdir",
+            "_precompile_args_path",
+            "_precompile_result_counter",
+            # Recomputed by _prepare() before checkpoint load
+            "_baseline_output",
+            "_baseline_post_args",
+            "_mutated_arg_indices",
+            "_effective_atol",
+            "_effective_rtol",
+            "_jobs",
+            "_autotune_metrics",
+            "_prepared",
+            "_skip_cache",
+            # Loaded separately via _load_crashed_configs()
+            "_crashed_config_strs",
+        }
+    )
+
+    def __getstate__(self) -> dict[str, Any]:
+        return {
+            k: v for k, v in self.__dict__.items() if k not in self._CHECKPOINT_EXCLUDE
+        }
+
+    _stable_hash: str | None = None
+
+    def _get_stable_hash(self) -> str:
+        """Get the full stable hash for this kernel's cache key (cached)."""
+        if self._stable_hash is None:
+            from .local_cache import LocalAutotuneCache
+
+            self._stable_hash = LocalAutotuneCache(self)._generate_key().stable_hash()
+        return self._stable_hash
+
+    def _try_load_checkpoint(self) -> bool:
+        """Attempt to load checkpoint from checkpoint dir. Returns True if successful."""
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return False
+
+        checkpoint_dir = Path(checkpoint_dir_str)
+        stable_hash = self._get_stable_hash()
+        checkpoint_file = checkpoint_dir / f"{stable_hash}.pt"
+
+        if not checkpoint_file.exists():
+            return False  # No matching checkpoint; start fresh
+
+        # Matching file exists, attempt to load
+        self.log(f"Resuming from checkpoint: {checkpoint_file}")
+        try:
+            with open(checkpoint_file, "rb") as f:
+                loaded = pickle.load(f)
+        except Exception as e:
+            raise exc.CheckpointError(
+                f"Failed to load checkpoint file '{checkpoint_file}': {e}\n"
+                f"The file may be corrupted. Delete it to start fresh."
+            ) from e
+
+        # Validate stable hash matches (guards against renamed/copied files)
+        loaded_hash = getattr(loaded, "_stable_hash", None)
+        if loaded_hash is not None and loaded_hash != self._get_stable_hash():
+            raise exc.CheckpointError(
+                "Checkpoint is incompatible: kernel, hardware, or input shapes "
+                "may have changed."
+            )
+
+        # Copy loaded search state into self (self already has kernel, args,
+        # log, etc. from __init__ and _prepare())
+        self.__dict__.update(loaded.__dict__)
+        self._recompile_after_checkpoint()
+        self.log(f"Resumed at generation {self._current_generation}")
+        return True
+
+    def _recompile_after_checkpoint(self) -> None:
+        """Recompile after loading a checkpoint. Override in subclasses."""
+
     def _compute_baseline(
         self,
     ) -> tuple[object, Sequence[int], Sequence[object] | None]:
@@ -629,6 +724,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
             The performance of the configuration in ms.
         """
         self._autotune_metrics.num_configs_tested += 1
+        self.counters["benchmark"] += 1
         self.log.debug(lambda: f"Running benchmark for {config!r}")
         _captured_output: list[str] = [""]
         _capture_ctx = (
@@ -1089,8 +1185,12 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
             torch.save(self.args, args_path)
             self._precompile_args_path = args_path
             exit_stack.callback(self.cleanup)
+
+            if not self._try_load_checkpoint():
+                self._init_search()
             try:
                 best = self._autotune()
+                self._cleanup_checkpoint()
             finally:
                 self._finalize_autotune_metrics()
             end = time.perf_counter()
@@ -1112,6 +1212,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
                 print(triton_code, file=sys.stderr)
         return best
 
+    def _init_search(self) -> None:
+        """
+        Initialize the search state for a fresh autotuning run.
+
+        This method is called when starting autotuning without a checkpoint.
+        Subclasses should override this to set up initial population and state.
+        After this method, _current_generation should be set to the generation
+        that _autotune() should start its loop from.
+        """
+
     def _autotune(self) -> Config:
         """
         Abstract method to perform the actual autotuning.
@@ -1123,6 +1233,68 @@ def _autotune(self) -> Config:
         """
         raise NotImplementedError
 
+    def save_checkpoint(self) -> Path | None:
+        """
+        Save current autotuner state to checkpoint file.
+
+        Only saves when autotune_checkpoint_dir is set (opt-in).
+        Overwrites the same file each generation (keyed by stable hash).
+        Uses pickle to serialize the entire autotuner object (minus unpicklable
+        fields excluded by __getstate__).
+
+        Returns:
+            Path to saved checkpoint file, or None if not saved
+        """
+        from ..runtime.kernel import BoundKernel
+
+        # External kernels don't support caching/checkpointing
+        if not isinstance(self.kernel, BoundKernel):
+            return None
+
+        if not self.kernel.is_cacheable():
+            return None
+
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return None  # Opt-in: no dir set, no saving
+
+        stable_hash = self._get_stable_hash()
+        checkpoint_dir = Path(checkpoint_dir_str)
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        checkpoint_path = checkpoint_dir / f"{stable_hash}.pt"
+
+        # Atomic write using temp file + rename
+        tmp = checkpoint_dir / f".tmp.{stable_hash}.{os.getpid()}"
+        with open(tmp, "wb") as f:
+            pickle.dump(self, f)
+        os.replace(tmp, checkpoint_path)
+
+        self.log(f"Checkpoint saved: {checkpoint_path}")
+        return checkpoint_path
+
+    def _cleanup_checkpoint(self) -> None:
+        """Delete checkpoint file on successful autotune completion.
+
+        Checkpoints are ephemeral in-progress state. Once autotuning
+        completes successfully, the result is cached normally and the
+        checkpoint is no longer needed.
+        """
+        checkpoint_dir_str = self.settings.autotune_checkpoint_dir
+        if checkpoint_dir_str is None:
+            return
+
+        stable_hash = self._get_stable_hash()
+        checkpoint_file = Path(checkpoint_dir_str) / f"{stable_hash}.pt"
+        if checkpoint_file.exists():
+            checkpoint_file.unlink()
+            self.log(f"Checkpoint cleaned up: {checkpoint_file}")
+
+        # Clean up crash-recovery artifacts
+        for suffix in (".pending_config", ".crashed_configs"):
+            artifact = Path(checkpoint_dir_str) / f"{stable_hash}{suffix}"
+            if artifact.exists():
+                artifact.unlink()
+
     def set_generation(self, generation: int) -> None:
         self._autotune_metrics.num_generations = generation
 
@@ -1177,6 +1349,15 @@ class PopulationMember:
     def perf(self) -> float:
         return self.perfs[-1]
 
+    def __getstate__(self) -> dict[str, Any]:
+        state = self.__dict__.copy()
+        state["fn"] = None  # compiled functions are not picklable
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        self.__dict__.update(state)
+        self.fn = _unset_fn
+
 
 def performance(member: PopulationMember) -> float:
     """
@@ -1570,6 +1751,14 @@ def rebenchmark_population(
             members = self.population
         self.rebenchmark([p for p in members if self.should_rebenchmark(p)], desc=desc)
 
+    def set_generation(self, generation: int) -> None:
+        if generation == self._current_generation:
+            return
+        self._current_generation = generation
+        super().set_generation(generation)
+        if generation > 0:
+            self.save_checkpoint()
+
     def statistics(self) -> str:
         """
         Generate statistics for the current population.
@@ -1579,6 +1768,27 @@ def statistics(self) -> str:
         """
         return population_statistics(self.population)
 
+    def _recompile_after_checkpoint(self) -> None:
+        """Recompile kernel functions for population members after checkpoint load."""
+        recompile_failures: list[tuple[PopulationMember, str]] = []
+        for member in self.population:
+            if member.fn is _unset_fn and member.status == "ok":
+                try:
+                    member.fn = self.kernel.compile_config(
+                        member.config, allow_print=False
+                    )
+                except Exception as e:
+                    member.fn = _unset_fn
+                    member.status = "error"
+                    member.perfs.append(inf)  # Ensure member won't be selected as best
+                    recompile_failures.append((member, str(e)))
+
+        if recompile_failures:
+            self.log(
+                f"Warning: {len(recompile_failures)} config(s) failed to recompile "
+                f"and will be skipped. First failure: {recompile_failures[0][1]}"
+            )
+
     def run_finishing_phase(
         self, best: PopulationMember, rounds: int
     ) -> PopulationMember:
```
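Two serialization patterns from this file can be demonstrated in isolation: the `__getstate__` exclusion filter (mirroring `_CHECKPOINT_EXCLUDE`) and the atomic temp-file-plus-`os.replace` write from `save_checkpoint`. A self-contained sketch with generic names (not Helion's API):

```python
import os
import pickle
from pathlib import Path

class Search:
    # Fields that must not be pickled: loggers, open handles, anything
    # recomputed on startup. Mirrors _CHECKPOINT_EXCLUDE in spirit.
    _EXCLUDE = frozenset({"log"})

    def __init__(self, checkpoint_dir: str) -> None:
        self.checkpoint_dir = Path(checkpoint_dir)
        self.generation = 0
        self.log = lambda msg: None  # lambdas cannot be pickled

    def __getstate__(self) -> dict:
        return {k: v for k, v in self.__dict__.items() if k not in self._EXCLUDE}

    def save_checkpoint(self) -> Path:
        # Atomic write: dump to a temp file, then rename over the target.
        # os.replace is atomic on the same filesystem, so a crash mid-write
        # can never leave a truncated checkpoint behind.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        path = self.checkpoint_dir / "state.pt"
        tmp = self.checkpoint_dir / f".tmp.{os.getpid()}"
        with open(tmp, "wb") as f:
            pickle.dump(self, f)
        os.replace(tmp, path)
        return path

    def load_checkpoint(self, path: Path) -> None:
        # Copy the pickled state over a freshly constructed object, which
        # still owns the unpicklable fields skipped by __getstate__.
        with open(path, "rb") as f:
            loaded = pickle.load(f)
        self.__dict__.update(loaded.__dict__)
```

The same division of labor appears above: `__init__` and `_prepare()` rebuild the excluded fields, and `_try_load_checkpoint` merges only the picklable search state back in.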

helion/autotuner/de_surrogate_hybrid.py

Lines changed: 12 additions & 6 deletions

```diff
@@ -135,12 +135,9 @@ def __init__(
         # Track all evaluations for surrogate training
         self.all_observations: list[tuple[FlatConfig, float]] = []
 
-    def _autotune(self) -> Config:
+    def _init_search(self) -> None:
         """
-        Run DE with surrogate-assisted selection.
-
-        Returns:
-            Best configuration found
+        Initialize DE with surrogate-assisted selection.
         """
         self.log("=" * 70)
         self.log("Differential Evolution with Surrogate-Assisted Selection")
@@ -174,8 +171,17 @@ def _autotune(self) -> Config:
         self.best_perf_history = [self.best.perf]
         self.generations_without_improvement = 0
 
+        self.set_generation(2)
+
+    def _autotune(self) -> Config:
+        """
+        Run DE with surrogate-assisted selection.
+
+        Returns:
+            Best configuration found
+        """
         # Evolution loop
-        for gen in range(2, self.max_generations + 1):
+        for gen in range(self._current_generation, self.max_generations + 1):
             self.set_generation(gen)
             self._evolve_generation(gen)
 
```
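The essential change above is starting the evolution loop at `self._current_generation` rather than the constant `2`, which lets a resumed run skip already-completed generations. A toy model of that pattern (illustrative class, not the real search):

```python
class ToySearch:
    # Mirrors the range(start, max + 1) resume pattern from _autotune().
    def __init__(self, max_generations: int) -> None:
        self.max_generations = max_generations
        self._current_generation = 2  # a fresh run starts at generation 2

    def run(self) -> list[int]:
        executed = []
        for gen in range(self._current_generation, self.max_generations + 1):
            self._current_generation = gen
            executed.append(gen)  # stands in for _evolve_generation(gen)
        return executed
```

After a checkpoint restore sets `_current_generation` to, say, 4, `run()` only executes generations 4 and onward instead of repeating the whole schedule.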
