helion/autotuner/aot_cache.py (1 addition, 1 deletion)

@@ -1013,7 +1013,7 @@ def _get_cache_key(self) -> BoundKernelInMemoryCacheKey:
         return self.kernel.kernel._create_bound_kernel_cache_key(
             self.kernel,
             tuple(self.args),
-            self.kernel.kernel.specialization_key(self.args),
+            self.kernel.kernel._base_specialization_key(self.args),
         )

     def _list_cache_entries(self) -> Sequence[tuple[str, LooseAutotuneCacheKey]]:
helion/autotuner/local_cache.py (1 addition, 1 deletion)

@@ -116,7 +116,7 @@ def _generate_key(self) -> LooseAutotuneCacheKey:
         in_memory_cache_key = self.kernel.kernel._create_bound_kernel_cache_key(
             self.kernel,
             tuple(self.args),
-            self.kernel.kernel.specialization_key(self.args),
+            self.kernel.kernel._base_specialization_key(self.args),
         )
         kernel_source = textwrap.dedent(inspect.getsource(self.kernel.kernel.fn))
         kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()
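Both caches now key on _base_specialization_key rather than the full specialization_key. The diff does not state the motivation, but a plausible reading is that the full key changes once hl.specialize() extras are recorded during the first compilation, so keying a persistent cache on it would be unstable. A minimal, self-contained Python sketch of that instability (TinyKernel and everything in it are illustrative stand-ins, not Helion's real internals):

# Sketch: the full key changes once compile-time extras are recorded,
# while the base key does not. All names are illustrative stand-ins.
from typing import Callable, Hashable, Sequence

class TinyKernel:
    def __init__(self) -> None:
        # base key -> key functions recorded during compilation
        self._specialize_extra: dict[tuple[Hashable, ...], list[Callable]] = {}

    def _base_specialization_key(self, args: Sequence[object]) -> tuple[Hashable, ...]:
        # Argument metadata only (here: just the argument types).
        return tuple(type(a).__name__ for a in args)

    def specialization_key(self, args: Sequence[object]) -> tuple[Hashable, ...]:
        base = self._base_specialization_key(args)
        extras = self._specialize_extra.get(base)
        if extras is not None:
            return base + tuple(f(args) for f in extras)
        return base

k = TinyKernel()
args = ([1, 2, 3],)
before = k.specialization_key(args)  # == base key; no extras recorded yet
k._specialize_extra[k._base_specialization_key(args)] = [lambda a: len(a[0])]
after = k.specialization_key(args)   # base key + (3,)
assert before != after               # full key is unstable across compilation
assert k._base_specialization_key(args) == before  # base key stays stable

Keying on the stable base key means cache entries written before and after the first compile resolve to the same bucket.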
helion/runtime/kernel.py (27 additions, 12 deletions)

@@ -221,7 +221,7 @@ def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
             raise TypeError(
                 f"Too many arguments passed to the kernel, expected: {self._num_params} got: {len(args)}."
             )
-        signature = self.specialization_key(args)
+        signature = self._base_specialization_key(args)
        cache_key = self._get_bound_kernel_cache_key(args, signature)
        bound_kernel = (
            None if cache_key is None else self._bound_kernels.get(cache_key, None)
@@ -240,18 +240,12 @@ def bind(self, args: tuple[object, ...]) -> BoundKernel[_R]:
             self._bound_kernels[cache_key] = bound_kernel
         return bound_kernel

-    def specialization_key(self, args: Sequence[object]) -> tuple[Hashable, ...]:
+    def _base_specialization_key(self, args: Sequence[object]) -> tuple[Hashable, ...]:
         """
-        Generate a specialization key for the given arguments.
-
-        This method generates a unique key for the arguments based on their types
-        and the corresponding extractor functions defined in `_specialization_extractors`.
-
-        Args:
-            args: The arguments to generate a specialization key for.
-
-        Returns:
-            Hashable: A hashable key representing the specialization of the arguments.
+        Generate the base specialization key from input argument metadata only,
+        using the per-type extractor functions defined in `_specialization_extractors`,
+        without any extras discovered during compilation. Used internally for
+        _specialize_extra lookups.
         """
         result = []
         assert len(args) <= len(self._annotations)
@@ -266,6 +260,27 @@ def specialization_key(self, args: Sequence[object]) -> tuple[Hashable, ...]:
             return (*result, self._key_fn(*args))
         return (*result,)

+    def specialization_key(self, args: Sequence[object]) -> tuple[Hashable, ...]:
+        """
+        Generate the full specialization key for the given arguments, including
+        any additional specialization constraints discovered during compilation
+        (e.g. from hl.specialize() calls).
+
+        Before the first compilation, these extras are not yet known and the
+        key may be incomplete.
+
+        Args:
+            args: The arguments to generate a specialization key for.
+
+        Returns:
+            Hashable: A hashable key representing the specialization of the arguments.
+        """
+        base = self._base_specialization_key(args)
+        extra_fns = self._specialize_extra.get(base)
+        if extra_fns is not None:
+            return base + tuple(s(args) for s in extra_fns)
+        return base
+
     def _specialization_key(self, obj: object) -> Hashable:
         """
         Helper used to generate a specialization key for the given object.
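The diff does not include the code that populates _specialize_extra; presumably compilation records an entry whenever hl.specialize() runs inside a kernel body. A hedged sketch of what that recording side might look like (record_specialize_extra and all of its parameters are hypothetical, introduced here only for illustration):

# Hypothetical sketch of the recording side, which this PR does not show.
from typing import Callable, Sequence

def record_specialize_extra(
    kernel,            # the Kernel being compiled (assumed to expose _specialize_extra)
    base_key: tuple,   # kernel._base_specialization_key(args) at compile time
    arg_index: int,    # which argument hl.specialize() inspected
    dim: int,          # which dimension it specialized on
) -> None:
    # Register a key function that re-extracts the specialized value from any
    # future argument tuple, so specialization_key() can append it to the base.
    extra: Callable[[Sequence[object]], object] = (
        lambda args: args[arg_index].size(dim)
    )
    kernel._specialize_extra.setdefault(base_key, []).append(extra)

Under this reading, specialization_key() stays cheap on the lookup path: it only re-runs the recorded extractors against the incoming arguments.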
test/test_specialize.py (29 additions, 0 deletions)

@@ -521,6 +521,35 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         torch._dynamo.mark_static(x2, -1)
         self.assertIsNot(fn.bind((x,)), fn.bind((x2,)))

+    @skipIfRefEager("specialization_key is not used in ref eager mode")
+    def test_specialization_key_includes_hl_specialize(self):
+        """Test that specialization_key() includes hl.specialize() extras after bind()."""
+
+        @helion.kernel(static_shapes=False, autotune_effort="none")
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            hl.specialize(x.size(-1))
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile] * 2
+            return out
+
+        a = torch.randn([128, 64], device=DEVICE)
+        b = torch.randn([128, 32], device=DEVICE)
+
+        # Before bind: keys are equal (extras not yet known)
+        key_a_before = fn.specialization_key((a,))
+        key_b_before = fn.specialization_key((b,))
+        self.assertEqual(key_a_before, key_b_before)
+
+        # After bind: keys must differ because hl.specialize(x.size(-1))
+        # makes the kernel depend on the last dimension
+        fn.bind((a,))
+        fn.bind((b,))
+
+        key_a_after = fn.specialization_key((a,))
+        key_b_after = fn.specialization_key((b,))
+        self.assertNotEqual(key_a_after, key_b_after)
+

 if __name__ == "__main__":
     unittest.main()