Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from .logger import maybe_dump_triton_failure
from .metrics import AutotuneMetrics
from .metrics import _run_post_autotune_hooks
from .precompile_future import PrecompileContext
from .precompile_future import PrecompileFuture as PrecompileFuture
from .precompile_future import _ExtractedLaunchArgs
from .progress_bar import iter_with_progress
Expand Down Expand Up @@ -329,7 +330,7 @@ class BaseSearch(BaseAutotuner):
_baseline_output: object
_mutated_arg_indices: Sequence[int] = []
_baseline_post_args: Sequence[object] | None
_jobs: int
_jobs: int = 1
_precompile_result_counter: count[int]
_effective_atol: float
_effective_rtol: float
Expand Down Expand Up @@ -860,6 +861,16 @@ def set_adaptive_compile_timeout(
f"bounds=[{min_seconds}s, {original_timeout}s])"
)

def _precompile_context(self) -> PrecompileContext:
    """Return a ``PrecompileContext`` populated from this search object.

    The context carries the settings, logger, kernel, kernel arguments and
    job count (``self._jobs``) as read from this object at call time.

    Returns:
        A new ``PrecompileContext`` holding those five values.
    """
    context = PrecompileContext(
        settings=self.settings,
        log=self.log,
        kernel=self.kernel,
        args=self.args,
        jobs=self._jobs,
    )
    return context

def create_precompile_future(
self, config: Config, fn: CompiledConfig
) -> PrecompileFuture:
Expand All @@ -876,8 +887,9 @@ def create_precompile_future(
A ``PrecompileFuture`` that resolves to True on success or False on
failure/timeout when called.
"""
ctx = self._precompile_context()
if not self.settings.autotune_precompile:
return PrecompileFuture.skip(self, config, True)
return PrecompileFuture.skip(ctx, config, True)
mode = self.settings.autotune_precompile
if mode not in {"fork", "spawn"}:
raise exc.InvalidAPIUsage("autotune_precompile must be 'fork' or 'spawn'")
Expand All @@ -891,7 +903,7 @@ def create_precompile_future(
args = self.args

return PrecompileFuture.create(
search=self,
ctx=ctx,
config=config,
fn=fn,
args=args,
Expand Down
104 changes: 61 additions & 43 deletions helion/autotuner/precompile_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,31 @@
from ..runtime.config import Config
from ..runtime.kernel import BoundKernel
from ..runtime.kernel import CompiledConfig
from .base_search import BaseSearch
from ..runtime.settings import Settings
from .base_search import _AutotunableKernel
from .logger import AutotuningLogger


@dataclasses.dataclass
class PrecompileContext:
    """Bundle of the search state that a ``PrecompileFuture`` depends on.

    A ``PrecompileFuture`` is handed this small record rather than a
    back-reference to the full search object.
    """

    # Autotuning settings (compile timeout, ignore_errors, etc.).
    settings: Settings
    # Logger for warnings/debug messages.
    log: AutotuningLogger
    # The kernel being autotuned (used for error reporting).
    kernel: _AutotunableKernel
    # The kernel arguments (used for repro logging on failure).
    args: Sequence[object]
    # Maximum number of concurrent precompile processes.
    jobs: int


def _write_result_file(result_path: str, message: dict[str, object]) -> None:
tmp_path = f"{result_path}.tmp"
with open(tmp_path, "wb") as f:
Expand Down Expand Up @@ -277,7 +297,7 @@ class PrecompileFuture:
Wraps a child process where we are precompiling a kernel.

Attributes:
search (BaseSearch): The search object that initiated the precompilation.
ctx (PrecompileContext): The context needed to drive the precompilation.
config (Config): The configuration to be precompiled.
process (mp.Process | None): The process running the precompilation.
timeout (float): The timeout for the precompilation.
Expand All @@ -286,7 +306,7 @@ class PrecompileFuture:
ok (bool | None): The result of the precompilation (True if successful, False otherwise).
"""

search: BaseSearch
ctx: PrecompileContext
config: Config
process: mp.Process | None
timeout: float
Expand Down Expand Up @@ -336,11 +356,11 @@ def start(self) -> None:
self.process.start()

@staticmethod
def skip(search: BaseSearch, config: Config, ok: bool) -> PrecompileFuture:
def skip(ctx: PrecompileContext, config: Config, ok: bool) -> PrecompileFuture:
"""Dummy precompile future that is already done."""
ts = time.time()
return PrecompileFuture(
search=search,
ctx=ctx,
config=config,
process=None,
timeout=0,
Expand All @@ -356,7 +376,7 @@ def skip(search: BaseSearch, config: Config, ok: bool) -> PrecompileFuture:

@staticmethod
def create(
search: BaseSearch,
ctx: PrecompileContext,
config: Config,
fn: CompiledConfig,
args: Sequence[object],
Expand All @@ -369,11 +389,11 @@ def create(
construction. Returns a ``skip`` future when the kernel is already
compiled (fork mode only).
"""
mode = search.settings.autotune_precompile
decorator = search.kernel.format_kernel_decorator(config, search.settings)
mode = ctx.settings.autotune_precompile
decorator = ctx.kernel.format_kernel_decorator(config, ctx.settings)

if mode == "spawn":
ctx = mp.get_context("spawn")
mp_ctx = mp.get_context("spawn")
assert args_path is not None
try:
fn_spec = _serialize_compiled_fn(fn)
Expand All @@ -384,32 +404,32 @@ def create(
) from err
process = cast(
"mp.Process",
ctx.Process(
mp_ctx.Process(
target=_run_kernel_in_subprocess_spawn,
args=(fn_spec, args_path, result_path, decorator),
),
)
process.daemon = True
else:
precompiler = _prepare_precompiler_for_fork(
fn, args, config, search.kernel, decorator, search.log
fn, args, config, ctx.kernel, decorator, ctx.log
)
if precompiler is None:
return PrecompileFuture.skip(search, config, True)
ctx = mp.get_context("fork")
return PrecompileFuture.skip(ctx, config, True)
mp_ctx = mp.get_context("fork")
process = cast(
"mp.Process",
ctx.Process(
mp_ctx.Process(
target=_run_kernel_in_subprocess_fork,
args=(precompiler, config, search.kernel, result_path, decorator),
args=(precompiler, config, ctx.kernel, result_path, decorator),
),
)
process.daemon = True
return PrecompileFuture(
search=search,
ctx=ctx,
config=config,
process=process,
timeout=search.settings.autotune_compile_timeout,
timeout=ctx.settings.autotune_compile_timeout,
result_path=result_path,
)

Expand Down Expand Up @@ -476,7 +496,7 @@ def _wait_for_all_step(
futures: list[PrecompileFuture],
) -> list[PrecompileFuture]:
"""Start up to the concurrency cap, wait for progress, and return remaining futures."""
cap = futures[0].search._jobs if futures else 1
cap = futures[0].ctx.jobs if futures else 1
running = [f for f in futures if f.started and f.ok is None and f.is_alive()]

# Start queued futures up to the cap
Expand Down Expand Up @@ -565,16 +585,16 @@ def _mark_complete(self) -> bool:
process.join(10)
msg = f"Timeout after {self.elapsed:.0f}s compiling {self.config}"
if process.is_alive():
if not self.search.settings.autotune_ignore_errors:
self.search.log.warning(
if not self.ctx.settings.autotune_ignore_errors:
self.ctx.log.warning(
msg,
"(SIGKILL required)",
)
process.kill()
process.join()
else:
if not self.search.settings.autotune_ignore_errors:
self.search.log.warning(msg)
if not self.ctx.settings.autotune_ignore_errors:
self.ctx.log.warning(msg)

self.ok = False
self.failure_reason = "timeout"
Expand Down Expand Up @@ -643,30 +663,30 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
return
exc_obj = error.to_exception()
maybe_dump_triton_failure(
self.search.kernel,
self.ctx.kernel,
self.config,
exc_obj,
remote_traceback=error.traceback,
captured_output=error.captured_output,
)
classification = error.classification or classify_triton_exception(exc_obj)
ignore_errors = self.search.settings.autotune_ignore_errors
ignore_errors = self.ctx.settings.autotune_ignore_errors
if ignore_errors:
classification = "debug"
if classification == "raise":
if raise_on_raise:
self._remote_error_handled = True
decorator = self.search.kernel.format_kernel_decorator(
self.config, self.search.settings
decorator = self.ctx.kernel.format_kernel_decorator(
self.config, self.ctx.settings
)
log_generated_triton_code_debug(
self.search.log,
self.search.kernel,
self.ctx.log,
self.ctx.kernel,
self.config,
prefix=f"Generated Triton code for {decorator}:",
)
self.search.kernel.maybe_log_repro(
self.search.log.error, self.search.args, self.config
self.ctx.kernel.maybe_log_repro(
self.ctx.log.error, self.ctx.args, self.config
)
raise exc.TritonError(
error=f"{type(exc_obj).__qualname__}: {exc_obj}",
Expand All @@ -675,30 +695,28 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
) from exc_obj
return

decorator = self.search.kernel.format_kernel_decorator(
self.config, self.search.settings
decorator = self.ctx.kernel.format_kernel_decorator(
self.config, self.ctx.settings
)
log_generated_triton_code_debug(
self.search.log,
self.search.kernel,
self.ctx.log,
self.ctx.kernel,
self.config,
prefix=f"Generated Triton code for {decorator}:",
)
formatted = format_triton_compile_failure(
self.config, exc_obj, self.search.kernel
)
formatted = format_triton_compile_failure(self.config, exc_obj, self.ctx.kernel)
if error.traceback:
formatted = (
f"{formatted}\nRemote traceback (spawned process):\n{error.traceback}"
)
if classification == "warn":
self.search.log.warning(formatted)
self.search.kernel.maybe_log_repro(
self.search.log.warning, self.search.args, self.config
self.ctx.log.warning(formatted)
self.ctx.kernel.maybe_log_repro(
self.ctx.log.warning, self.ctx.args, self.config
)
elif not ignore_errors:
self.search.log.debug(formatted)
self.search.kernel.maybe_log_repro(
self.search.log.debug, self.search.args, self.config
self.ctx.log.debug(formatted)
self.ctx.kernel.maybe_log_repro(
self.ctx.log.debug, self.ctx.args, self.config
)
self._remote_error_handled = True
8 changes: 4 additions & 4 deletions test/test_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,7 @@ def make_bad_config_produce_wrong_output(
"create_precompile_future",
side_effect=lambda config, fn: (
base_search_module.PrecompileFuture.skip(
search, config, True
search._precompile_context(), config, True
)
),
)
Expand Down Expand Up @@ -1173,7 +1173,7 @@ def wrong_fn(*fn_args, **fn_kwargs):
"create_precompile_future",
side_effect=lambda config, fn: (
base_search_module.PrecompileFuture.skip(
search, config, True
search._precompile_context(), config, True
)
),
)
Expand Down Expand Up @@ -1298,7 +1298,7 @@ def make_bad_config_produce_wrong_output(
search,
"create_precompile_future",
side_effect=lambda config, fn: base_search_module.PrecompileFuture.skip(
search, config, True
search._precompile_context(), config, True
),
):
# Bad config should be filtered out by accuracy check
Expand Down Expand Up @@ -2026,7 +2026,7 @@ def patched(*fn_args, **fn_kwargs):
search,
"create_precompile_future",
side_effect=lambda config, fn: base_search_module.PrecompileFuture.skip(
search, config, True
search._precompile_context(), config, True
),
):
# bad_config has a few large diffs — custom check should accept it
Expand Down
Loading