Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,9 +1345,11 @@ def _get_current_hardware_and_specialization(
hardware = get_device_name(extract_device(self.args))

inner_kernel = getattr(self.kernel, "kernel", None)
if inner_kernel is None or not hasattr(inner_kernel, "specialization_key"):
if inner_kernel is None or not hasattr(
inner_kernel, "_base_specialization_key"
):
return hardware, None
spec_key = inner_kernel.specialization_key(self.args)
spec_key = inner_kernel._base_specialization_key(self.args)
specialization_key = str(_normalize_spec_key(spec_key))

return hardware, specialization_key
Expand Down
68 changes: 68 additions & 0 deletions test/test_best_available.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,74 @@ def test_cache_matching_with_code_object_in_spec_key(self):
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].config.config["block_sizes"], [64, 128])

def test_find_similar_matches_with_specialize_extras(self):
    """FROM_BEST_AVAILABLE matches cache entries when hl.specialize() adds
    extras to the full specialization key.

    The cache stores _base_specialization_key (no extras) but the kernel's
    specialization_key() appends hl.specialize() discoveries. The lookup
    must use the base key so it matches the stored format.
    """
    # The cache entry and the mocked config_spec must agree on this
    # structural-fingerprint hash, otherwise the lookup rejects the entry.
    fingerprint = (("block_sizes", 2, 1, 1),)
    fp_hash = hashlib.sha256(repr(fingerprint).encode("utf-8")).hexdigest()

    # Base key: what local_cache.py persists alongside the config.
    base_spec_key = ("tensor_spec",)
    # Full key has an extra element from hl.specialize(x.size(1))
    full_spec_key = ("tensor_spec", 256)

    with tempfile.TemporaryDirectory() as cache_dir:
        # Cache entry stored with base key (as local_cache.py does)
        self._write_best_config(
            cache_dir,
            "specialize.best_config",
            hardware="NVIDIA GeForce RTX 4090",
            spec_key=str(base_spec_key),
            source_hash="hash1",
            config_dict={"block_sizes": [64, 128], "num_warps": 4},
            config_spec_hash=fp_hash,
            flat_config=[64, 128, 4],
        )

        # Minimal stand-in for a search object: only the attributes that
        # the code under test reads are populated explicitly.
        mock_search = MagicMock()
        mock_search._skip_cache = False
        mock_search.settings = MagicMock()
        mock_search.settings.autotune_best_available_max_cache_scan = 500
        mock_search.args = [torch.tensor([1.0], device=DEVICE)]
        mock_search.config_spec = MagicMock()
        mock_search.config_spec.structural_fingerprint_hash = MagicMock(
            return_value=fp_hash
        )

        # Set up kernel with base key != full key (simulates hl.specialize())
        mock_kernel = MagicMock()
        mock_kernel._base_specialization_key = MagicMock(return_value=base_spec_key)
        mock_kernel.specialization_key = MagicMock(return_value=full_spec_key)
        mock_search.kernel.kernel = mock_kernel

        # Use the REAL _get_current_hardware_and_specialization
        # (bound to the mock via a lambda) so the test exercises the
        # production key-selection logic rather than a mocked return.
        mock_search._get_current_hardware_and_specialization = lambda: (
            PopulationBasedSearch._get_current_hardware_and_specialization(
                mock_search
            )
        )

        # Redirect the cache directory to the temp dir and pin the device
        # name so the stored entry's hardware field matches at lookup time.
        with (
            patch(
                "helion.autotuner.local_cache.get_helion_cache_dir",
                return_value=Path(cache_dir),
            ),
            patch(
                "helion.autotuner.base_search.get_device_name",
                return_value="NVIDIA GeForce RTX 4090",
            ),
        ):
            entries = PopulationBasedSearch._find_similar_cached_configs(
                mock_search, max_configs=10
            )

        # The entry stored under the BASE key must be found even though the
        # kernel's full specialization key carries hl.specialize() extras.
        self.assertEqual(len(entries), 1)
        self.assertEqual(entries[0].config.config["block_sizes"], [64, 128])


class TestIterCacheEntries(unittest.TestCase):
"""Tests for the iter_cache_entries() module-level API in local_cache."""
Expand Down
Loading