Skip to content

Commit 5506ee0

Browse files
committed
Decouple PrecompileFuture from BaseSearch via PrecompileContext
Replace the back-reference to the full BaseSearch object with a narrow 5-field PrecompileContext dataclass (settings, log, kernel, args, jobs). This makes PrecompileFuture's dependencies explicit and lets PrecompileFuture be tested in isolation. This change is part of a larger effort to introduce a BenchmarkProvider abstraction that encapsulates the benchmarking pipeline. Before this change, PrecompileFuture held a reference to BaseSearch in order to access those 5 fields. Once the BenchmarkProvider owns precompilation, it will need to pass its own context into PrecompileFuture without involving the search object; PrecompileContext makes that possible.
1 parent 78dba3a commit 5506ee0

File tree

3 files changed

+79
-49
lines changed

3 files changed

+79
-49
lines changed

helion/autotuner/base_search.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from .logger import maybe_dump_triton_failure
5757
from .metrics import AutotuneMetrics
5858
from .metrics import _run_post_autotune_hooks
59+
from .precompile_future import PrecompileContext
5960
from .precompile_future import PrecompileFuture as PrecompileFuture
6061
from .precompile_future import _ExtractedLaunchArgs
6162
from .progress_bar import iter_with_progress
@@ -329,7 +330,7 @@ class BaseSearch(BaseAutotuner):
329330
_baseline_output: object
330331
_mutated_arg_indices: Sequence[int] = []
331332
_baseline_post_args: Sequence[object] | None
332-
_jobs: int
333+
_jobs: int = 1
333334
_precompile_result_counter: count[int]
334335
_effective_atol: float
335336
_effective_rtol: float
@@ -860,6 +861,16 @@ def set_adaptive_compile_timeout(
860861
f"bounds=[{min_seconds}s, {original_timeout}s])"
861862
)
862863

864+
def _precompile_context(self) -> PrecompileContext:
865+
"""Build the narrow context that PrecompileFuture needs."""
866+
return PrecompileContext(
867+
settings=self.settings,
868+
log=self.log,
869+
kernel=self.kernel,
870+
args=self.args,
871+
jobs=self._jobs,
872+
)
873+
863874
def create_precompile_future(
864875
self, config: Config, fn: CompiledConfig
865876
) -> PrecompileFuture:
@@ -876,8 +887,9 @@ def create_precompile_future(
876887
A ``PrecompileFuture`` that resolves to True on success or False on
877888
failure/timeout when called.
878889
"""
890+
ctx = self._precompile_context()
879891
if not self.settings.autotune_precompile:
880-
return PrecompileFuture.skip(self, config, True)
892+
return PrecompileFuture.skip(ctx, config, True)
881893
mode = self.settings.autotune_precompile
882894
if mode not in {"fork", "spawn"}:
883895
raise exc.InvalidAPIUsage("autotune_precompile must be 'fork' or 'spawn'")
@@ -891,7 +903,7 @@ def create_precompile_future(
891903
args = self.args
892904

893905
return PrecompileFuture.create(
894-
search=self,
906+
ctx=ctx,
895907
config=config,
896908
fn=fn,
897909
args=args,

helion/autotuner/precompile_future.py

Lines changed: 60 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,31 @@
4141
from ..runtime.config import Config
4242
from ..runtime.kernel import BoundKernel
4343
from ..runtime.kernel import CompiledConfig
44-
from .base_search import BaseSearch
44+
from ..runtime.settings import Settings
4545
from .base_search import _AutotunableKernel
4646
from .logger import AutotuningLogger
4747

4848

49+
@dataclasses.dataclass
50+
class PrecompileContext:
51+
"""Narrow context that PrecompileFuture uses instead of a back-reference
52+
to the full search object.
53+
54+
Attributes:
55+
settings: Autotuning settings (compile timeout, ignore_errors, etc.).
56+
log: Logger for warnings/debug messages.
57+
kernel: The kernel being autotuned (used for error reporting).
58+
args: The kernel arguments (used for repro logging on failure).
59+
jobs: Maximum number of concurrent precompile processes.
60+
"""
61+
62+
settings: Settings
63+
log: AutotuningLogger
64+
kernel: _AutotunableKernel
65+
args: Sequence[object]
66+
jobs: int
67+
68+
4969
def _write_result_file(result_path: str, message: dict[str, object]) -> None:
5070
tmp_path = f"{result_path}.tmp"
5171
with open(tmp_path, "wb") as f:
@@ -286,7 +306,7 @@ class PrecompileFuture:
286306
ok (bool | None): The result of the precompilation (True if successful, False otherwise).
287307
"""
288308

289-
search: BaseSearch
309+
ctx: PrecompileContext
290310
config: Config
291311
process: mp.Process | None
292312
timeout: float
@@ -336,11 +356,11 @@ def start(self) -> None:
336356
self.process.start()
337357

338358
@staticmethod
339-
def skip(search: BaseSearch, config: Config, ok: bool) -> PrecompileFuture:
359+
def skip(ctx: PrecompileContext, config: Config, ok: bool) -> PrecompileFuture:
340360
"""Dummy precompile future that is already done."""
341361
ts = time.time()
342362
return PrecompileFuture(
343-
search=search,
363+
ctx=ctx,
344364
config=config,
345365
process=None,
346366
timeout=0,
@@ -356,7 +376,7 @@ def skip(search: BaseSearch, config: Config, ok: bool) -> PrecompileFuture:
356376

357377
@staticmethod
358378
def create(
359-
search: BaseSearch,
379+
ctx: PrecompileContext,
360380
config: Config,
361381
fn: CompiledConfig,
362382
args: Sequence[object],
@@ -369,11 +389,11 @@ def create(
369389
construction. Returns a ``skip`` future when the kernel is already
370390
compiled (fork mode only).
371391
"""
372-
mode = search.settings.autotune_precompile
373-
decorator = search.kernel.format_kernel_decorator(config, search.settings)
392+
mode = ctx.settings.autotune_precompile
393+
decorator = ctx.kernel.format_kernel_decorator(config, ctx.settings)
374394

375395
if mode == "spawn":
376-
ctx = mp.get_context("spawn")
396+
mp_ctx = mp.get_context("spawn")
377397
assert args_path is not None
378398
try:
379399
fn_spec = _serialize_compiled_fn(fn)
@@ -384,32 +404,32 @@ def create(
384404
) from err
385405
process = cast(
386406
"mp.Process",
387-
ctx.Process(
407+
mp_ctx.Process(
388408
target=_run_kernel_in_subprocess_spawn,
389409
args=(fn_spec, args_path, result_path, decorator),
390410
),
391411
)
392412
process.daemon = True
393413
else:
394414
precompiler = _prepare_precompiler_for_fork(
395-
fn, args, config, search.kernel, decorator, search.log
415+
fn, args, config, ctx.kernel, decorator, ctx.log
396416
)
397417
if precompiler is None:
398-
return PrecompileFuture.skip(search, config, True)
399-
ctx = mp.get_context("fork")
418+
return PrecompileFuture.skip(ctx, config, True)
419+
mp_ctx = mp.get_context("fork")
400420
process = cast(
401421
"mp.Process",
402-
ctx.Process(
422+
mp_ctx.Process(
403423
target=_run_kernel_in_subprocess_fork,
404-
args=(precompiler, config, search.kernel, result_path, decorator),
424+
args=(precompiler, config, ctx.kernel, result_path, decorator),
405425
),
406426
)
407427
process.daemon = True
408428
return PrecompileFuture(
409-
search=search,
429+
ctx=ctx,
410430
config=config,
411431
process=process,
412-
timeout=search.settings.autotune_compile_timeout,
432+
timeout=ctx.settings.autotune_compile_timeout,
413433
result_path=result_path,
414434
)
415435

@@ -476,7 +496,7 @@ def _wait_for_all_step(
476496
futures: list[PrecompileFuture],
477497
) -> list[PrecompileFuture]:
478498
"""Start up to the concurrency cap, wait for progress, and return remaining futures."""
479-
cap = futures[0].search._jobs if futures else 1
499+
cap = futures[0].ctx.jobs if futures else 1
480500
running = [f for f in futures if f.started and f.ok is None and f.is_alive()]
481501

482502
# Start queued futures up to the cap
@@ -565,16 +585,16 @@ def _mark_complete(self) -> bool:
565585
process.join(10)
566586
msg = f"Timeout after {self.elapsed:.0f}s compiling {self.config}"
567587
if process.is_alive():
568-
if not self.search.settings.autotune_ignore_errors:
569-
self.search.log.warning(
588+
if not self.ctx.settings.autotune_ignore_errors:
589+
self.ctx.log.warning(
570590
msg,
571591
"(SIGKILL required)",
572592
)
573593
process.kill()
574594
process.join()
575595
else:
576-
if not self.search.settings.autotune_ignore_errors:
577-
self.search.log.warning(msg)
596+
if not self.ctx.settings.autotune_ignore_errors:
597+
self.ctx.log.warning(msg)
578598

579599
self.ok = False
580600
self.failure_reason = "timeout"
@@ -643,30 +663,30 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
643663
return
644664
exc_obj = error.to_exception()
645665
maybe_dump_triton_failure(
646-
self.search.kernel,
666+
self.ctx.kernel,
647667
self.config,
648668
exc_obj,
649669
remote_traceback=error.traceback,
650670
captured_output=error.captured_output,
651671
)
652672
classification = error.classification or classify_triton_exception(exc_obj)
653-
ignore_errors = self.search.settings.autotune_ignore_errors
673+
ignore_errors = self.ctx.settings.autotune_ignore_errors
654674
if ignore_errors:
655675
classification = "debug"
656676
if classification == "raise":
657677
if raise_on_raise:
658678
self._remote_error_handled = True
659-
decorator = self.search.kernel.format_kernel_decorator(
660-
self.config, self.search.settings
679+
decorator = self.ctx.kernel.format_kernel_decorator(
680+
self.config, self.ctx.settings
661681
)
662682
log_generated_triton_code_debug(
663-
self.search.log,
664-
self.search.kernel,
683+
self.ctx.log,
684+
self.ctx.kernel,
665685
self.config,
666686
prefix=f"Generated Triton code for {decorator}:",
667687
)
668-
self.search.kernel.maybe_log_repro(
669-
self.search.log.error, self.search.args, self.config
688+
self.ctx.kernel.maybe_log_repro(
689+
self.ctx.log.error, self.ctx.args, self.config
670690
)
671691
raise exc.TritonError(
672692
error=f"{type(exc_obj).__qualname__}: {exc_obj}",
@@ -675,30 +695,28 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
675695
) from exc_obj
676696
return
677697

678-
decorator = self.search.kernel.format_kernel_decorator(
679-
self.config, self.search.settings
698+
decorator = self.ctx.kernel.format_kernel_decorator(
699+
self.config, self.ctx.settings
680700
)
681701
log_generated_triton_code_debug(
682-
self.search.log,
683-
self.search.kernel,
702+
self.ctx.log,
703+
self.ctx.kernel,
684704
self.config,
685705
prefix=f"Generated Triton code for {decorator}:",
686706
)
687-
formatted = format_triton_compile_failure(
688-
self.config, exc_obj, self.search.kernel
689-
)
707+
formatted = format_triton_compile_failure(self.config, exc_obj, self.ctx.kernel)
690708
if error.traceback:
691709
formatted = (
692710
f"{formatted}\nRemote traceback (spawned process):\n{error.traceback}"
693711
)
694712
if classification == "warn":
695-
self.search.log.warning(formatted)
696-
self.search.kernel.maybe_log_repro(
697-
self.search.log.warning, self.search.args, self.config
713+
self.ctx.log.warning(formatted)
714+
self.ctx.kernel.maybe_log_repro(
715+
self.ctx.log.warning, self.ctx.args, self.config
698716
)
699717
elif not ignore_errors:
700-
self.search.log.debug(formatted)
701-
self.search.kernel.maybe_log_repro(
702-
self.search.log.debug, self.search.args, self.config
718+
self.ctx.log.debug(formatted)
719+
self.ctx.kernel.maybe_log_repro(
720+
self.ctx.log.debug, self.ctx.args, self.config
703721
)
704722
self._remote_error_handled = True

test/test_autotuner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,7 +1091,7 @@ def make_bad_config_produce_wrong_output(
10911091
"create_precompile_future",
10921092
side_effect=lambda config, fn: (
10931093
base_search_module.PrecompileFuture.skip(
1094-
search, config, True
1094+
search._precompile_context(), config, True
10951095
)
10961096
),
10971097
)
@@ -1173,7 +1173,7 @@ def wrong_fn(*fn_args, **fn_kwargs):
11731173
"create_precompile_future",
11741174
side_effect=lambda config, fn: (
11751175
base_search_module.PrecompileFuture.skip(
1176-
search, config, True
1176+
search._precompile_context(), config, True
11771177
)
11781178
),
11791179
)
@@ -1298,7 +1298,7 @@ def make_bad_config_produce_wrong_output(
12981298
search,
12991299
"create_precompile_future",
13001300
side_effect=lambda config, fn: base_search_module.PrecompileFuture.skip(
1301-
search, config, True
1301+
search._precompile_context(), config, True
13021302
),
13031303
):
13041304
# Bad config should be filtered out by accuracy check
@@ -2026,7 +2026,7 @@ def patched(*fn_args, **fn_kwargs):
20262026
search,
20272027
"create_precompile_future",
20282028
side_effect=lambda config, fn: base_search_module.PrecompileFuture.skip(
2029-
search, config, True
2029+
search._precompile_context(), config, True
20302030
),
20312031
):
20322032
# bad_config has a few large diffs — custom check should accept it

0 commit comments

Comments
 (0)