4141 from ..runtime .config import Config
4242 from ..runtime .kernel import BoundKernel
4343 from ..runtime .kernel import CompiledConfig
44- from .base_search import BaseSearch
44+ from .. runtime . settings import Settings
4545 from .base_search import _AutotunableKernel
4646 from .logger import AutotuningLogger
4747
4848
49+ @dataclasses .dataclass
50+ class PrecompileContext :
51+ """Narrow context that PrecompileFuture uses instead of a back-reference
52+ to the full search object.
53+
54+ Attributes:
55+ settings: Autotuning settings (compile timeout, ignore_errors, etc.).
56+ log: Logger for warnings/debug messages.
57+ kernel: The kernel being autotuned (used for error reporting).
58+ args: The kernel arguments (used for repro logging on failure).
59+ jobs: Maximum number of concurrent precompile processes.
60+ """
61+
62+ settings : Settings
63+ log : AutotuningLogger
64+ kernel : _AutotunableKernel
65+ args : Sequence [object ]
66+ jobs : int
67+
68+
4969def _write_result_file (result_path : str , message : dict [str , object ]) -> None :
5070 tmp_path = f"{ result_path } .tmp"
5171 with open (tmp_path , "wb" ) as f :
@@ -286,7 +306,7 @@ class PrecompileFuture:
286306 ok (bool | None): The result of the precompilation (True if successful, False otherwise).
287307 """
288308
289- search : BaseSearch
309+ ctx : PrecompileContext
290310 config : Config
291311 process : mp .Process | None
292312 timeout : float
@@ -336,11 +356,11 @@ def start(self) -> None:
336356 self .process .start ()
337357
338358 @staticmethod
339- def skip (search : BaseSearch , config : Config , ok : bool ) -> PrecompileFuture :
359+ def skip (ctx : PrecompileContext , config : Config , ok : bool ) -> PrecompileFuture :
340360 """Dummy precompile future that is already done."""
341361 ts = time .time ()
342362 return PrecompileFuture (
343- search = search ,
363+ ctx = ctx ,
344364 config = config ,
345365 process = None ,
346366 timeout = 0 ,
@@ -356,7 +376,7 @@ def skip(search: BaseSearch, config: Config, ok: bool) -> PrecompileFuture:
356376
357377 @staticmethod
358378 def create (
359- search : BaseSearch ,
379+ ctx : PrecompileContext ,
360380 config : Config ,
361381 fn : CompiledConfig ,
362382 args : Sequence [object ],
@@ -369,11 +389,11 @@ def create(
369389 construction. Returns a ``skip`` future when the kernel is already
370390 compiled (fork mode only).
371391 """
372- mode = search .settings .autotune_precompile
373- decorator = search .kernel .format_kernel_decorator (config , search .settings )
392+ mode = ctx .settings .autotune_precompile
393+ decorator = ctx .kernel .format_kernel_decorator (config , ctx .settings )
374394
375395 if mode == "spawn" :
376- ctx = mp .get_context ("spawn" )
396+ mp_ctx = mp .get_context ("spawn" )
377397 assert args_path is not None
378398 try :
379399 fn_spec = _serialize_compiled_fn (fn )
@@ -384,32 +404,32 @@ def create(
384404 ) from err
385405 process = cast (
386406 "mp.Process" ,
387- ctx .Process (
407+ mp_ctx .Process (
388408 target = _run_kernel_in_subprocess_spawn ,
389409 args = (fn_spec , args_path , result_path , decorator ),
390410 ),
391411 )
392412 process .daemon = True
393413 else :
394414 precompiler = _prepare_precompiler_for_fork (
395- fn , args , config , search .kernel , decorator , search .log
415+ fn , args , config , ctx .kernel , decorator , ctx .log
396416 )
397417 if precompiler is None :
398- return PrecompileFuture .skip (search , config , True )
399- ctx = mp .get_context ("fork" )
418+ return PrecompileFuture .skip (ctx , config , True )
419+ mp_ctx = mp .get_context ("fork" )
400420 process = cast (
401421 "mp.Process" ,
402- ctx .Process (
422+ mp_ctx .Process (
403423 target = _run_kernel_in_subprocess_fork ,
404- args = (precompiler , config , search .kernel , result_path , decorator ),
424+ args = (precompiler , config , ctx .kernel , result_path , decorator ),
405425 ),
406426 )
407427 process .daemon = True
408428 return PrecompileFuture (
409- search = search ,
429+ ctx = ctx ,
410430 config = config ,
411431 process = process ,
412- timeout = search .settings .autotune_compile_timeout ,
432+ timeout = ctx .settings .autotune_compile_timeout ,
413433 result_path = result_path ,
414434 )
415435
@@ -476,7 +496,7 @@ def _wait_for_all_step(
476496 futures : list [PrecompileFuture ],
477497 ) -> list [PrecompileFuture ]:
478498 """Start up to the concurrency cap, wait for progress, and return remaining futures."""
479- cap = futures [0 ].search . _jobs if futures else 1
499+ cap = futures [0 ].ctx . jobs if futures else 1
480500 running = [f for f in futures if f .started and f .ok is None and f .is_alive ()]
481501
482502 # Start queued futures up to the cap
@@ -565,16 +585,16 @@ def _mark_complete(self) -> bool:
565585 process .join (10 )
566586 msg = f"Timeout after { self .elapsed :.0f} s compiling { self .config } "
567587 if process .is_alive ():
568- if not self .search .settings .autotune_ignore_errors :
569- self .search .log .warning (
588+ if not self .ctx .settings .autotune_ignore_errors :
589+ self .ctx .log .warning (
570590 msg ,
571591 "(SIGKILL required)" ,
572592 )
573593 process .kill ()
574594 process .join ()
575595 else :
576- if not self .search .settings .autotune_ignore_errors :
577- self .search .log .warning (msg )
596+ if not self .ctx .settings .autotune_ignore_errors :
597+ self .ctx .log .warning (msg )
578598
579599 self .ok = False
580600 self .failure_reason = "timeout"
@@ -643,30 +663,30 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
643663 return
644664 exc_obj = error .to_exception ()
645665 maybe_dump_triton_failure (
646- self .search .kernel ,
666+ self .ctx .kernel ,
647667 self .config ,
648668 exc_obj ,
649669 remote_traceback = error .traceback ,
650670 captured_output = error .captured_output ,
651671 )
652672 classification = error .classification or classify_triton_exception (exc_obj )
653- ignore_errors = self .search .settings .autotune_ignore_errors
673+ ignore_errors = self .ctx .settings .autotune_ignore_errors
654674 if ignore_errors :
655675 classification = "debug"
656676 if classification == "raise" :
657677 if raise_on_raise :
658678 self ._remote_error_handled = True
659- decorator = self .search .kernel .format_kernel_decorator (
660- self .config , self .search .settings
679+ decorator = self .ctx .kernel .format_kernel_decorator (
680+ self .config , self .ctx .settings
661681 )
662682 log_generated_triton_code_debug (
663- self .search .log ,
664- self .search .kernel ,
683+ self .ctx .log ,
684+ self .ctx .kernel ,
665685 self .config ,
666686 prefix = f"Generated Triton code for { decorator } :" ,
667687 )
668- self .search .kernel .maybe_log_repro (
669- self .search .log .error , self .search .args , self .config
688+ self .ctx .kernel .maybe_log_repro (
689+ self .ctx .log .error , self .ctx .args , self .config
670690 )
671691 raise exc .TritonError (
672692 error = f"{ type (exc_obj ).__qualname__ } : { exc_obj } " ,
@@ -675,30 +695,28 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
675695 ) from exc_obj
676696 return
677697
678- decorator = self .search .kernel .format_kernel_decorator (
679- self .config , self .search .settings
698+ decorator = self .ctx .kernel .format_kernel_decorator (
699+ self .config , self .ctx .settings
680700 )
681701 log_generated_triton_code_debug (
682- self .search .log ,
683- self .search .kernel ,
702+ self .ctx .log ,
703+ self .ctx .kernel ,
684704 self .config ,
685705 prefix = f"Generated Triton code for { decorator } :" ,
686706 )
687- formatted = format_triton_compile_failure (
688- self .config , exc_obj , self .search .kernel
689- )
707+ formatted = format_triton_compile_failure (self .config , exc_obj , self .ctx .kernel )
690708 if error .traceback :
691709 formatted = (
692710 f"{ formatted } \n Remote traceback (spawned process):\n { error .traceback } "
693711 )
694712 if classification == "warn" :
695- self .search .log .warning (formatted )
696- self .search .kernel .maybe_log_repro (
697- self .search .log .warning , self .search .args , self .config
713+ self .ctx .log .warning (formatted )
714+ self .ctx .kernel .maybe_log_repro (
715+ self .ctx .log .warning , self .ctx .args , self .config
698716 )
699717 elif not ignore_errors :
700- self .search .log .debug (formatted )
701- self .search .kernel .maybe_log_repro (
702- self .search .log .debug , self .search .args , self .config
718+ self .ctx .log .debug (formatted )
719+ self .ctx .kernel .maybe_log_repro (
720+ self .ctx .log .debug , self .ctx .args , self .config
703721 )
704722 self ._remote_error_handled = True
0 commit comments