Skip to content
2 changes: 1 addition & 1 deletion helion/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ._utils import counters
from .runtime.settings import _get_backend
from .runtime.settings import is_pallas_interpret
from helion.autotuner.base_search import _clone_args
from helion.autotuner.benchmark_provider import _clone_args

if _get_backend() == "pallas":
from .autotuner.benchmarking import compute_repeat_generic as compute_repeat
Expand Down
13 changes: 4 additions & 9 deletions helion/autotuner/aot_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import os
from pathlib import Path
import sys
import tempfile
import traceback
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -757,11 +756,9 @@ def measure_all_configs(self) -> list[tuple[Config, float]]:
old_precompile = self.autotuner.settings.autotune_precompile
self.autotuner.settings.autotune_precompile = None

# Set up tmpdir if needed (normally done inside autotune())
tmpdir_created = False
if self.autotuner._precompile_tmpdir is None:
self.autotuner._precompile_tmpdir = tempfile.TemporaryDirectory()
tmpdir_created = True
# Set up provider resources if needed (normally done inside autotune())
benchmark_provider = self.autotuner.benchmark_provider
benchmark_provider.setup()

try:
for i, config in enumerate(all_configs):
Expand Down Expand Up @@ -805,9 +802,7 @@ def measure_all_configs(self) -> list[tuple[Config, float]]:
finally:
# Restore settings
self.autotuner.settings.autotune_precompile = old_precompile
if tmpdir_created and self.autotuner._precompile_tmpdir is not None:
self.autotuner._precompile_tmpdir.cleanup()
self.autotuner._precompile_tmpdir = None
benchmark_provider.cleanup()

print(
f"[AOT measure] Completed: {len(results)}/{len(all_configs)} configs succeeded",
Expand Down
Loading
Loading