
Commit 0104696 (1 parent: 9c6b4e0)

[Autotuner] Auto-checkpoint feature and ability to resume from checkpoint
File tree: 12 files changed, +2071 −109 lines

docs/api/settings.md

Lines changed: 8 additions & 0 deletions

````diff
@@ -197,6 +197,13 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 
 Users can still override individual ``autotune_*`` settings; explicit values win over the preset. Controlled by ``HELION_AUTOTUNE_EFFORT``.
 
+.. autoattribute:: Settings.autotune_checkpoint_id
+
+   Checkpoint ID for resuming autotuning from a previous checkpoint. When set, the autotuner attempts to load
+   state from a checkpoint file matching this ID, allowing long-running autotuning sessions to be interrupted
+   and resumed. The checkpoint ID contains a hash prefix that identifies the kernel, hardware, and input shapes.
+   If the hash doesn't match, a ``CheckpointError`` is raised.
+   Controlled by ``HELION_AUTOTUNE_CHECKPOINT_ID``.
 
 ```
@@ -295,6 +302,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
 | ``HELION_AUTOTUNE_PROGRESS_BAR`` | ``autotune_progress_bar`` | Enable or disable the progress bar UI during autotuning. |
 | ``HELION_AUTOTUNE_IGNORE_ERRORS`` | ``autotune_ignore_errors`` | Continue autotuning even when recoverable runtime errors occur. |
 | ``HELION_AUTOTUNE_CONFIG_OVERRIDES`` | ``autotune_config_overrides`` | Supply JSON forcing particular autotuner config key/value pairs. |
+| ``HELION_AUTOTUNE_CHECKPOINT_ID`` | ``autotune_checkpoint_id`` | Checkpoint ID for resuming autotuning from a previous checkpoint. |
 | ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
 | ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. |
 | ``HELION_ASSERT_CACHE_HIT`` | ``AutotuneCacheBase`` | When set to ``1``, require a cache hit; raises ``CacheAssertionError`` on cache miss with detailed diagnostics. |
````

docs/deployment_autotuning.md

Lines changed: 24 additions & 0 deletions

````diff
@@ -104,6 +104,30 @@ tuning time versus coverage, or try different search algorithms.
 need more reproducibility; see {doc}`api/settings`. Note this only
 affects which configs are tried, not the timing results.
 
+### Checkpointing Long-Running Autotuning
+
+For very long autotuning sessions, you can save and resume state using
+checkpoints. This is useful when tuning might be interrupted (e.g., preemptible
+instances) or when you want to continue tuning from a previous unfinished run.
+
+The simplest approach is to use the `HELION_AUTOTUNE_CHECKPOINT_ID` environment
+variable. When autotuning runs, it periodically saves checkpoints and logs the
+checkpoint ID. To resume, set this environment variable to the checkpoint ID
+from a previous run.
+
+```bash
+# First run - autotuning will log checkpoint IDs as it progresses:
+# "Checkpoint saved: .../autotuner_checkpoints/a1b2c3d4_1706123456_e5f6g7h8.checkpoint"
+# "To resume from this checkpoint, set HELION_AUTOTUNE_CHECKPOINT_ID=a1b2c3d4_1706123456_e5f6g7h8 ..."
+python run_kernel.py
+
+# If interrupted, resume from the last checkpoint:
+HELION_AUTOTUNE_CHECKPOINT_ID=a1b2c3d4_1706123456_e5f6g7h8 python run_kernel.py
+```
+
+The checkpoint ID contains a hash prefix that identifies the kernel, hardware,
+and input shapes. If the hash doesn't match, a `CheckpointError` is raised.
+
 ## Deploy a Single Config
 
 If one configuration wins for every production call, bake it into the decorator:
````
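The example ID logged above (`a1b2c3d4_1706123456_e5f6g7h8`) suggests a three-part underscore-separated layout: the hash prefix that identifies kernel, hardware, and input shapes, followed by what looks like a Unix timestamp and a run-specific suffix. A sketch of splitting one apart — the meanings of the two trailing fields are inferred from the example, not a documented format:

```python
from typing import NamedTuple

class CheckpointId(NamedTuple):
    hash_prefix: str  # identifies kernel, hardware, and input shapes
    timestamp: int    # appears to be the Unix time the checkpoint was created
    run_suffix: str   # distinguishes runs sharing the same hash prefix

def parse_checkpoint_id(checkpoint_id: str) -> CheckpointId:
    """Split an ID like 'a1b2c3d4_1706123456_e5f6g7h8' into its parts."""
    hash_prefix, timestamp, run_suffix = checkpoint_id.split("_")
    return CheckpointId(hash_prefix, int(timestamp), run_suffix)
```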

helion/_testing.py

Lines changed: 22 additions & 0 deletions

```diff
@@ -10,13 +10,15 @@
 import operator
 import os
 from pathlib import Path
+import random
 import re
 import sys
 from typing import TYPE_CHECKING
 from typing import Callable
 from typing import Generator
 import unittest
 
+import numpy as np
 import pytest
 import torch
 from torch.utils._pytree import tree_map
@@ -40,6 +42,26 @@
 from .runtime.kernel import Kernel
 
 
+def seed_rng(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)  # noqa: NPY002
+    torch.manual_seed(seed)
+
+
+@contextlib.contextmanager
+def fork_rng() -> Generator[None, None, None]:
+    """Context manager that forks all RNGs and restores original state on exit."""
+    python_state = random.getstate()
+    numpy_state = np.random.get_state()  # noqa: NPY002
+
+    with torch.random.fork_rng():
+        try:
+            yield
+        finally:
+            random.setstate(python_state)
+            np.random.set_state(numpy_state)  # noqa: NPY002
+
+
 def _strip_launcher_args(value: str) -> str:
     strip_pairs = []
     if supports_amd_cdna_tunables():
```
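The new `fork_rng` helper layers Python and NumPy state save/restore around `torch.random.fork_rng()`, so draws inside the block leave all three generators untouched outside it. The same pattern reduced to the standard-library `random` module, so it runs without torch or numpy — a simplified sketch, not the Helion helper itself:

```python
import contextlib
import random
from typing import Generator

@contextlib.contextmanager
def fork_python_rng() -> Generator[None, None, None]:
    """Snapshot the `random` module's state and restore it on exit."""
    state = random.getstate()
    try:
        yield
    finally:
        random.setstate(state)

random.seed(0)
before = random.random()  # first draw from seed 0

random.seed(0)
with fork_python_rng():
    random.random()  # draws inside the fork consume the stream...
    random.random()
after = random.random()  # ...but the state outside is unaffected

assert before == after
```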
