11from __future__ import annotations
22
3+ import dataclasses
34import enum
45import math
56from typing import TYPE_CHECKING
1112from .effort_profile import PATTERN_SEARCH_DEFAULTS
1213
1314if TYPE_CHECKING :
14- from collections .abc import Iterator
1515 from collections .abc import Sequence
1616
1717 from ..autotuner .effort_profile import AutotuneEffortProfile
@@ -31,6 +31,26 @@ class InitialPopulationStrategy(enum.Enum):
3131 """Start from default config plus up to 20 best matching cached configs from previous runs."""
3232
3333
@dataclasses.dataclass
class PatternSearchCopy:
    """
    Mutable per-copy state for one copy of the pattern search.

    Each copy explores from a different starting point. The `copies` parameter
    controls how many of these run in parallel. The search driver creates one
    instance per starting point (``PatternSearchCopy(current=m)``) and updates
    it in place after each generation.
    """

    # The current best member for this search copy. Neighbors for the next
    # generation are generated from this member's flat values, and it is
    # replaced by the generation's best candidate while the copy is not stopped.
    current: PopulationMember

    # Whether this search copy has stopped (no more unvisited candidates, or
    # early stopping triggered). Stopped copies are skipped by the driver.
    stopped: bool = False

    # Remaining patience for early stopping (decremented when no improvement).
    # None means no patience tracking (stop immediately on no improvement).
    # When the value reaches 0 the next early-stopping hit stops the copy.
    patience_remaining: int | None = None
52+
53+
3454class PatternSearch (PopulationBasedSearch ):
3555 """Search that explores single-parameter perturbations around the current best."""
3656
@@ -87,6 +107,8 @@ def __init__(
87107 self .num_neighbors_cap = num_neighbors_cap
88108 self .compile_timeout_lower_bound = compile_timeout_lower_bound
89109 self .compile_timeout_quantile = compile_timeout_quantile
110+ self .visited : set [Config ] = set ()
111+ self .search_copies : list [PatternSearchCopy ] = []
90112
91113 @classmethod
92114 def get_kwargs_from_profile (
@@ -128,17 +150,18 @@ def _generate_initial_population_flat(self) -> list[FlatConfig]:
128150 return pop
129151 return self .config_gen .random_population_flat (self .initial_population )
130152
131- def _autotune (self ) -> Config :
132- initial_population_name = self .initial_population_strategy .name
153+ def _init_search (self ) -> None :
133154 self .log (
134- f"Starting PatternSearch with initial_population={ initial_population_name } , copies={ self .copies } , max_generations={ self .max_generations } "
155+ f"Starting { type (self ).__name__ } with initial_population={ self .initial_population_strategy .name } ,"
156+ f" copies={ self .copies } ,"
157+ f" max_generations={ self .max_generations } "
135158 )
136- visited : set [ Config ] = set ()
159+ self . visited . clear ()
137160 self .population = []
138161 for flat_config in self ._generate_initial_population_flat ():
139162 member = self .make_unbenchmarked (flat_config )
140- if member .config not in visited :
141- visited .add (member .config )
163+ if member .config not in self . visited :
164+ self . visited .add (member .config )
142165 self .population .append (member )
143166 self .parallel_benchmark_population (self .population , desc = "Initial population" )
144167
@@ -163,20 +186,36 @@ def _autotune(self) -> Config:
163186 if not starting_points :
164187 raise exc .NoConfigFound
165188
166- search_copies = [self ._pattern_search_from (m , visited ) for m in starting_points ]
167- for generation in range (1 , self .max_generations + 1 ):
189+ self .search_copies = [PatternSearchCopy (current = m ) for m in starting_points ]
190+ self .set_generation (1 )
191+
192+ def _autotune (self ) -> Config :
193+ for generation in range (self ._current_generation , self .max_generations + 1 ):
194+ self .set_generation (generation )
168195 prior_best = self .best
169196 new_population = {id (prior_best ): prior_best }
170197 num_neighbors = 0
171198 num_active = 0
172- for search_copy in search_copies :
173- added = next (search_copy , ())
174- if added :
175- assert len (added ) > 1
176- num_active += 1
177- num_neighbors += len (added ) - 1
178- for member in added :
179- new_population [id (member )] = member
199+ active_copies : list [tuple [PatternSearchCopy , list [PopulationMember ]]] = []
200+ for search_copy in self .search_copies :
201+ if search_copy .stopped :
202+ continue
203+ candidates = [search_copy .current ]
204+ for flat_config in self ._generate_neighbors (
205+ search_copy .current .flat_values
206+ ):
207+ new_member = self .make_unbenchmarked (flat_config )
208+ if new_member .config not in self .visited :
209+ self .visited .add (new_member .config )
210+ candidates .append (new_member )
211+ if len (candidates ) <= 1 :
212+ search_copy .stopped = True
213+ continue
214+ num_active += 1
215+ num_neighbors += len (candidates ) - 1
216+ for member in candidates :
217+ new_population [id (member )] = member
218+ active_copies .append ((search_copy , candidates ))
180219 if num_active == 0 :
181220 break
182221
@@ -189,46 +228,35 @@ def _autotune(self) -> Config:
189228 # compile any unbenchmarked members in parallel
190229 unbenchmarked = [m for m in self .population if len (m .perfs ) == 0 ]
191230 if unbenchmarked :
192- self .set_generation (generation )
193231 self .parallel_benchmark_population (
194232 unbenchmarked , desc = f"Generation { generation } :"
195233 )
196234 # higher-accuracy rebenchmark
197235 self .rebenchmark_population (
198236 self .population , desc = f"Generation { generation } : verifying top configs"
199237 )
238+
239+ # Update each search copy after rebenchmarking (uses refined perf values)
240+ for search_copy , candidates in active_copies :
241+ best = min (candidates , key = performance )
242+ if self ._check_early_stopping (best , search_copy .current ):
243+ if (
244+ search_copy .patience_remaining is not None
245+ and search_copy .patience_remaining > 0
246+ ):
247+ search_copy .patience_remaining -= 1
248+ else :
249+ search_copy .stopped = True
250+ if not search_copy .stopped :
251+ search_copy .current = best
252+
200253 # Log final statistics for this generation
201254 self .log (f"Generation { generation } complete:" , self .statistics )
202255
203256 # Run finishing phase to simplify the best configuration
204257 best = self .run_finishing_phase (self .best , self .finishing_rounds )
205258 return best .config
206259
207- def _pattern_search_from (
208- self , current : PopulationMember , visited : set [Config ]
209- ) -> Iterator [list [PopulationMember ]]:
210- """
211- Run a single copy of pattern search from the given starting point.
212-
213- We use a generator and yield the new population at each generation so that we can
214- run multiple copies of pattern search in parallel.
215- """
216- for _ in range (self .max_generations ):
217- candidates = [current ]
218- for flat_config in self ._generate_neighbors (current .flat_values ):
219- new_member = self .make_unbenchmarked (flat_config )
220- if new_member .config not in visited :
221- visited .add (new_member .config )
222- candidates .append (new_member )
223- if len (candidates ) <= 1 :
224- return # no new candidates, stop searching
225- yield candidates # yield new population to benchmark in parallel
226- # update search copy and check early stopping criteria
227- best = min (candidates , key = performance )
228- if self ._check_early_stopping (best , current ):
229- return
230- current = best
231-
232260 def _check_early_stopping (
233261 self , best : PopulationMember , current : PopulationMember
234262 ) -> bool :
0 commit comments