Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
self.args: Sequence[object] = args
self.log = AutotuningLogger(self.settings)
self.best_perf_so_far = inf
self._current_generation = 0
self._prepared = False
self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
self._precompile_args_path: str | None = None
Expand Down Expand Up @@ -1089,6 +1090,8 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
torch.save(self.args, args_path)
self._precompile_args_path = args_path
exit_stack.callback(self.cleanup)

self._init_search()
try:
best = self._autotune()
finally:
Expand All @@ -1112,6 +1115,15 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
print(triton_code, file=sys.stderr)
return best

def _init_search(self) -> None:
    """
    Initialize the search state for a fresh autotuning run.

    Base implementation is a deliberate no-op. Subclasses override this to
    build the initial population and any per-search state. On return,
    ``_current_generation`` must be set (via ``set_generation``) to the
    generation that ``_autotune()`` should start its loop from, so that
    the evolution loop can resume from the correct point.
    """

def _autotune(self) -> Config:
"""
Abstract method to perform the actual autotuning.
Expand Down Expand Up @@ -1570,6 +1582,12 @@ def rebenchmark_population(
members = self.population
self.rebenchmark([p for p in members if self.should_rebenchmark(p)], desc=desc)

def set_generation(self, generation: int) -> None:
    """
    Record the current generation, forwarding to the parent class.

    No-op when *generation* already matches the tracked value, so repeated
    calls within the same generation do not re-trigger the parent hook.
    """
    if generation != self._current_generation:
        self._current_generation = generation
        super().set_generation(generation)

def statistics(self) -> str:
"""
Generate statistics for the current population.
Expand Down
18 changes: 12 additions & 6 deletions helion/autotuner/de_surrogate_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,9 @@ def __init__(
# Track all evaluations for surrogate training
self.all_observations: list[tuple[FlatConfig, float]] = []

def _autotune(self) -> Config:
def _init_search(self) -> None:
"""
Run DE with surrogate-assisted selection.

Returns:
Best configuration found
Initialize DE with surrogate-assisted selection.
"""
self.log("=" * 70)
self.log("Differential Evolution with Surrogate-Assisted Selection")
Expand Down Expand Up @@ -174,8 +171,17 @@ def _autotune(self) -> Config:
self.best_perf_history = [self.best.perf]
self.generations_without_improvement = 0

self.set_generation(2)

def _autotune(self) -> Config:
"""
Run DE with surrogate-assisted selection.

Returns:
Best configuration found
"""
# Evolution loop
for gen in range(2, self.max_generations + 1):
for gen in range(self._current_generation, self.max_generations + 1):
self.set_generation(gen)
self._evolve_generation(gen)

Expand Down
11 changes: 9 additions & 2 deletions helion/autotuner/differential_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def check_early_stopping(self) -> bool:
self.generations_without_improvement = 0
return False

def _autotune(self) -> Config:
def _init_search(self) -> None:
early_stopping_enabled = (
self.min_improvement_delta is not None and self.patience is not None
)
Expand Down Expand Up @@ -265,7 +265,14 @@ def _autotune(self) -> Config:
self.best_perf_history = [self.best.perf]
self.generations_without_improvement = 0

for i in range(2, self.max_generations):
self.set_generation(2)

def _autotune(self) -> Config:
early_stopping_enabled = (
self.min_improvement_delta is not None and self.patience is not None
)

for i in range(self._current_generation, self.max_generations):
self.set_generation(i)
self.log(f"Generation {i} starting")
replaced = self.evolve_population()
Expand Down
114 changes: 71 additions & 43 deletions helion/autotuner/pattern_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import dataclasses
import enum
import math
from typing import TYPE_CHECKING
Expand All @@ -11,7 +12,6 @@
from .effort_profile import PATTERN_SEARCH_DEFAULTS

if TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import Sequence

from ..autotuner.effort_profile import AutotuneEffortProfile
Expand All @@ -31,6 +31,26 @@ class InitialPopulationStrategy(enum.Enum):
"""Start from default config plus up to 20 best matching cached configs from previous runs."""


@dataclasses.dataclass
class PatternSearchCopy:
    """
    Mutable state for one independent copy of the pattern search.

    Each copy explores neighbors from its own starting point; the `copies`
    parameter controls how many run in parallel within a generation.
    """

    # The current best member this copy is expanding neighbors around.
    current: PopulationMember

    # True once this copy has terminated: either it produced no unvisited
    # neighbor candidates, or early stopping fired with no patience left.
    stopped: bool = False

    # Remaining patience for early stopping; decremented each generation
    # that fails to improve. None disables patience tracking, meaning the
    # copy stops immediately on the first non-improving generation.
    patience_remaining: int | None = None


class PatternSearch(PopulationBasedSearch):
"""Search that explores single-parameter perturbations around the current best."""

Expand Down Expand Up @@ -87,6 +107,8 @@ def __init__(
self.num_neighbors_cap = num_neighbors_cap
self.compile_timeout_lower_bound = compile_timeout_lower_bound
self.compile_timeout_quantile = compile_timeout_quantile
self.visited: set[Config] = set()
self.search_copies: list[PatternSearchCopy] = []

@classmethod
def get_kwargs_from_profile(
Expand Down Expand Up @@ -128,17 +150,18 @@ def _generate_initial_population_flat(self) -> list[FlatConfig]:
return pop
return self.config_gen.random_population_flat(self.initial_population)

def _autotune(self) -> Config:
initial_population_name = self.initial_population_strategy.name
def _init_search(self) -> None:
self.log(
f"Starting PatternSearch with initial_population={initial_population_name}, copies={self.copies}, max_generations={self.max_generations}"
f"Starting {type(self).__name__} with initial_population={self.initial_population_strategy.name},"
f" copies={self.copies},"
f" max_generations={self.max_generations}"
)
visited: set[Config] = set()
self.visited.clear()
self.population = []
for flat_config in self._generate_initial_population_flat():
member = self.make_unbenchmarked(flat_config)
if member.config not in visited:
visited.add(member.config)
if member.config not in self.visited:
self.visited.add(member.config)
self.population.append(member)
self.parallel_benchmark_population(self.population, desc="Initial population")

Expand All @@ -163,20 +186,36 @@ def _autotune(self) -> Config:
if not starting_points:
raise exc.NoConfigFound

search_copies = [self._pattern_search_from(m, visited) for m in starting_points]
for generation in range(1, self.max_generations + 1):
self.search_copies = [PatternSearchCopy(current=m) for m in starting_points]
self.set_generation(1)

def _autotune(self) -> Config:
for generation in range(self._current_generation, self.max_generations + 1):
self.set_generation(generation)
prior_best = self.best
new_population = {id(prior_best): prior_best}
num_neighbors = 0
num_active = 0
for search_copy in search_copies:
added = next(search_copy, ())
if added:
assert len(added) > 1
num_active += 1
num_neighbors += len(added) - 1
for member in added:
new_population[id(member)] = member
active_copies: list[tuple[PatternSearchCopy, list[PopulationMember]]] = []
for search_copy in self.search_copies:
if search_copy.stopped:
continue
candidates = [search_copy.current]
for flat_config in self._generate_neighbors(
search_copy.current.flat_values
):
new_member = self.make_unbenchmarked(flat_config)
if new_member.config not in self.visited:
self.visited.add(new_member.config)
candidates.append(new_member)
if len(candidates) <= 1:
search_copy.stopped = True
continue
num_active += 1
num_neighbors += len(candidates) - 1
for member in candidates:
new_population[id(member)] = member
active_copies.append((search_copy, candidates))
if num_active == 0:
break

Expand All @@ -189,46 +228,35 @@ def _autotune(self) -> Config:
# compile any unbenchmarked members in parallel
unbenchmarked = [m for m in self.population if len(m.perfs) == 0]
if unbenchmarked:
self.set_generation(generation)
self.parallel_benchmark_population(
unbenchmarked, desc=f"Generation {generation}:"
)
# higher-accuracy rebenchmark
self.rebenchmark_population(
self.population, desc=f"Generation {generation}: verifying top configs"
)

# Update each search copy after rebenchmarking (uses refined perf values)
for search_copy, candidates in active_copies:
best = min(candidates, key=performance)
if self._check_early_stopping(best, search_copy.current):
if (
search_copy.patience_remaining is not None
and search_copy.patience_remaining > 0
):
search_copy.patience_remaining -= 1
else:
search_copy.stopped = True
if not search_copy.stopped:
search_copy.current = best

# Log final statistics for this generation
self.log(f"Generation {generation} complete:", self.statistics)

# Run finishing phase to simplify the best configuration
best = self.run_finishing_phase(self.best, self.finishing_rounds)
return best.config

def _pattern_search_from(
self, current: PopulationMember, visited: set[Config]
) -> Iterator[list[PopulationMember]]:
"""
Run a single copy of pattern search from the given starting point.

We use a generator and yield the new population at each generation so that we can
run multiple copies of pattern search in parallel.
"""
for _ in range(self.max_generations):
candidates = [current]
for flat_config in self._generate_neighbors(current.flat_values):
new_member = self.make_unbenchmarked(flat_config)
if new_member.config not in visited:
visited.add(new_member.config)
candidates.append(new_member)
if len(candidates) <= 1:
return # no new candidates, stop searching
yield candidates # yield new population to benchmark in parallel
# update search copy and check early stopping criteria
best = min(candidates, key=performance)
if self._check_early_stopping(best, current):
return
current = best

def _check_early_stopping(
self, best: PopulationMember, current: PopulationMember
) -> bool:
Expand Down
Loading
Loading