Skip to content

Commit 4f24dc9

Browse files
committed
[Auto-Recovery] Refactor search algorithms for checkpointable state
Replace generator-based pattern search with a PatternSearchCopy dataclass (generators can't be pickled). Extract _init_search() from _autotune() in all search algorithms so checkpoint resume can skip initialization and jump directly into the search loop. Add _current_generation tracking to BaseSearch. Pure refactoring, no behavioral change. stack-info: PR: #1946, branch: yf225/stack/95
1 parent f28637d commit 4f24dc9

File tree

5 files changed

+260
-211
lines changed

5 files changed

+260
-211
lines changed

helion/autotuner/base_search.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
343343
self.args: Sequence[object] = args
344344
self.log = AutotuningLogger(self.settings)
345345
self.best_perf_so_far = inf
346+
self._current_generation = 0
346347
self._prepared = False
347348
self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
348349
self._precompile_args_path: str | None = None
@@ -1089,6 +1090,8 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
10891090
torch.save(self.args, args_path)
10901091
self._precompile_args_path = args_path
10911092
exit_stack.callback(self.cleanup)
1093+
1094+
self._init_search()
10921095
try:
10931096
best = self._autotune()
10941097
finally:
@@ -1112,6 +1115,15 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
11121115
print(triton_code, file=sys.stderr)
11131116
return best
11141117

1118+
def _init_search(self) -> None:
1119+
"""
1120+
Initialize the search state for a fresh autotuning run.
1121+
1122+
Subclasses should override this to set up initial population and state.
1123+
After this method, _current_generation should be set to the generation
1124+
that _autotune() should start its loop from.
1125+
"""
1126+
11151127
def _autotune(self) -> Config:
11161128
"""
11171129
Abstract method to perform the actual autotuning.
@@ -1570,6 +1582,12 @@ def rebenchmark_population(
15701582
members = self.population
15711583
self.rebenchmark([p for p in members if self.should_rebenchmark(p)], desc=desc)
15721584

1585+
def set_generation(self, generation: int) -> None:
1586+
if generation == self._current_generation:
1587+
return
1588+
self._current_generation = generation
1589+
super().set_generation(generation)
1590+
15731591
def statistics(self) -> str:
15741592
"""
15751593
Generate statistics for the current population.

helion/autotuner/de_surrogate_hybrid.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,9 @@ def __init__(
135135
# Track all evaluations for surrogate training
136136
self.all_observations: list[tuple[FlatConfig, float]] = []
137137

138-
def _autotune(self) -> Config:
138+
def _init_search(self) -> None:
139139
"""
140-
Run DE with surrogate-assisted selection.
141-
142-
Returns:
143-
Best configuration found
140+
Initialize DE with surrogate-assisted selection.
144141
"""
145142
self.log("=" * 70)
146143
self.log("Differential Evolution with Surrogate-Assisted Selection")
@@ -174,8 +171,17 @@ def _autotune(self) -> Config:
174171
self.best_perf_history = [self.best.perf]
175172
self.generations_without_improvement = 0
176173

174+
self.set_generation(2)
175+
176+
def _autotune(self) -> Config:
177+
"""
178+
Run DE with surrogate-assisted selection.
179+
180+
Returns:
181+
Best configuration found
182+
"""
177183
# Evolution loop
178-
for gen in range(2, self.max_generations + 1):
184+
for gen in range(self._current_generation, self.max_generations + 1):
179185
self.set_generation(gen)
180186
self._evolve_generation(gen)
181187

helion/autotuner/differential_evolution.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def check_early_stopping(self) -> bool:
236236
self.generations_without_improvement = 0
237237
return False
238238

239-
def _autotune(self) -> Config:
239+
def _init_search(self) -> None:
240240
early_stopping_enabled = (
241241
self.min_improvement_delta is not None and self.patience is not None
242242
)
@@ -265,7 +265,14 @@ def _autotune(self) -> Config:
265265
self.best_perf_history = [self.best.perf]
266266
self.generations_without_improvement = 0
267267

268-
for i in range(2, self.max_generations):
268+
self.set_generation(2)
269+
270+
def _autotune(self) -> Config:
271+
early_stopping_enabled = (
272+
self.min_improvement_delta is not None and self.patience is not None
273+
)
274+
275+
for i in range(self._current_generation, self.max_generations):
269276
self.set_generation(i)
270277
self.log(f"Generation {i} starting")
271278
replaced = self.evolve_population()

helion/autotuner/pattern_search.py

Lines changed: 71 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import dataclasses
34
import enum
45
import math
56
from typing import TYPE_CHECKING
@@ -11,7 +12,6 @@
1112
from .effort_profile import PATTERN_SEARCH_DEFAULTS
1213

1314
if TYPE_CHECKING:
14-
from collections.abc import Iterator
1515
from collections.abc import Sequence
1616

1717
from ..autotuner.effort_profile import AutotuneEffortProfile
@@ -31,6 +31,26 @@ class InitialPopulationStrategy(enum.Enum):
3131
"""Start from default config plus up to 20 best matching cached configs from previous runs."""
3232

3333

34+
@dataclasses.dataclass
35+
class PatternSearchCopy:
36+
"""
37+
Represents one copy of the pattern search.
38+
39+
Each copy explores from a different starting point. The `copies` parameter
40+
controls how many of these run in parallel.
41+
"""
42+
43+
# The current best member for this search copy.
44+
current: PopulationMember
45+
46+
# Whether this search copy has stopped (no more candidates or early stopping).
47+
stopped: bool = False
48+
49+
# Remaining patience for early stopping (decremented when no improvement).
50+
# None means no patience tracking (stop immediately on no improvement).
51+
patience_remaining: int | None = None
52+
53+
3454
class PatternSearch(PopulationBasedSearch):
3555
"""Search that explores single-parameter perturbations around the current best."""
3656

@@ -87,6 +107,8 @@ def __init__(
87107
self.num_neighbors_cap = num_neighbors_cap
88108
self.compile_timeout_lower_bound = compile_timeout_lower_bound
89109
self.compile_timeout_quantile = compile_timeout_quantile
110+
self.visited: set[Config] = set()
111+
self.search_copies: list[PatternSearchCopy] = []
90112

91113
@classmethod
92114
def get_kwargs_from_profile(
@@ -128,17 +150,18 @@ def _generate_initial_population_flat(self) -> list[FlatConfig]:
128150
return pop
129151
return self.config_gen.random_population_flat(self.initial_population)
130152

131-
def _autotune(self) -> Config:
132-
initial_population_name = self.initial_population_strategy.name
153+
def _init_search(self) -> None:
133154
self.log(
134-
f"Starting PatternSearch with initial_population={initial_population_name}, copies={self.copies}, max_generations={self.max_generations}"
155+
f"Starting {type(self).__name__} with initial_population={self.initial_population_strategy.name},"
156+
f" copies={self.copies},"
157+
f" max_generations={self.max_generations}"
135158
)
136-
visited: set[Config] = set()
159+
self.visited.clear()
137160
self.population = []
138161
for flat_config in self._generate_initial_population_flat():
139162
member = self.make_unbenchmarked(flat_config)
140-
if member.config not in visited:
141-
visited.add(member.config)
163+
if member.config not in self.visited:
164+
self.visited.add(member.config)
142165
self.population.append(member)
143166
self.parallel_benchmark_population(self.population, desc="Initial population")
144167

@@ -163,20 +186,36 @@ def _autotune(self) -> Config:
163186
if not starting_points:
164187
raise exc.NoConfigFound
165188

166-
search_copies = [self._pattern_search_from(m, visited) for m in starting_points]
167-
for generation in range(1, self.max_generations + 1):
189+
self.search_copies = [PatternSearchCopy(current=m) for m in starting_points]
190+
self.set_generation(1)
191+
192+
def _autotune(self) -> Config:
193+
for generation in range(self._current_generation, self.max_generations + 1):
194+
self.set_generation(generation)
168195
prior_best = self.best
169196
new_population = {id(prior_best): prior_best}
170197
num_neighbors = 0
171198
num_active = 0
172-
for search_copy in search_copies:
173-
added = next(search_copy, ())
174-
if added:
175-
assert len(added) > 1
176-
num_active += 1
177-
num_neighbors += len(added) - 1
178-
for member in added:
179-
new_population[id(member)] = member
199+
active_copies: list[tuple[PatternSearchCopy, list[PopulationMember]]] = []
200+
for search_copy in self.search_copies:
201+
if search_copy.stopped:
202+
continue
203+
candidates = [search_copy.current]
204+
for flat_config in self._generate_neighbors(
205+
search_copy.current.flat_values
206+
):
207+
new_member = self.make_unbenchmarked(flat_config)
208+
if new_member.config not in self.visited:
209+
self.visited.add(new_member.config)
210+
candidates.append(new_member)
211+
if len(candidates) <= 1:
212+
search_copy.stopped = True
213+
continue
214+
num_active += 1
215+
num_neighbors += len(candidates) - 1
216+
for member in candidates:
217+
new_population[id(member)] = member
218+
active_copies.append((search_copy, candidates))
180219
if num_active == 0:
181220
break
182221

@@ -189,46 +228,35 @@ def _autotune(self) -> Config:
189228
# compile any unbenchmarked members in parallel
190229
unbenchmarked = [m for m in self.population if len(m.perfs) == 0]
191230
if unbenchmarked:
192-
self.set_generation(generation)
193231
self.parallel_benchmark_population(
194232
unbenchmarked, desc=f"Generation {generation}:"
195233
)
196234
# higher-accuracy rebenchmark
197235
self.rebenchmark_population(
198236
self.population, desc=f"Generation {generation}: verifying top configs"
199237
)
238+
239+
# Update each search copy after rebenchmarking (uses refined perf values)
240+
for search_copy, candidates in active_copies:
241+
best = min(candidates, key=performance)
242+
if self._check_early_stopping(best, search_copy.current):
243+
if (
244+
search_copy.patience_remaining is not None
245+
and search_copy.patience_remaining > 0
246+
):
247+
search_copy.patience_remaining -= 1
248+
else:
249+
search_copy.stopped = True
250+
if not search_copy.stopped:
251+
search_copy.current = best
252+
200253
# Log final statistics for this generation
201254
self.log(f"Generation {generation} complete:", self.statistics)
202255

203256
# Run finishing phase to simplify the best configuration
204257
best = self.run_finishing_phase(self.best, self.finishing_rounds)
205258
return best.config
206259

207-
def _pattern_search_from(
208-
self, current: PopulationMember, visited: set[Config]
209-
) -> Iterator[list[PopulationMember]]:
210-
"""
211-
Run a single copy of pattern search from the given starting point.
212-
213-
We use a generator and yield the new population at each generation so that we can
214-
run multiple copies of pattern search in parallel.
215-
"""
216-
for _ in range(self.max_generations):
217-
candidates = [current]
218-
for flat_config in self._generate_neighbors(current.flat_values):
219-
new_member = self.make_unbenchmarked(flat_config)
220-
if new_member.config not in visited:
221-
visited.add(new_member.config)
222-
candidates.append(new_member)
223-
if len(candidates) <= 1:
224-
return # no new candidates, stop searching
225-
yield candidates # yield new population to benchmark in parallel
226-
# update search copy and check early stopping criteria
227-
best = min(candidates, key=performance)
228-
if self._check_early_stopping(best, current):
229-
return
230-
current = best
231-
232260
def _check_early_stopping(
233261
self, best: PopulationMember, current: PopulationMember
234262
) -> bool:

0 commit comments

Comments (0)