Skip to content

Commit 5686af9

Browse files
committed
custom config filter
stack-info: PR: #1847, branch: shunting314/stack/25
1 parent 161d0a0 commit 5686af9

File tree

3 files changed

+130
-5
lines changed

3 files changed

+130
-5
lines changed

helion/autotuner/base_search.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ class BenchmarkResult(NamedTuple):
176176
config: Config
177177
fn: Callable[..., object]
178178
perf: float
179-
status: Literal["ok", "error", "timeout", "peer_compilation_fail"]
179+
status: Literal["ok", "error", "timeout", "peer_compilation_fail", "filtered"]
180180
compile_time: float | None
181181

182182

@@ -916,6 +916,33 @@ def _benchmark(
916916
A list of BenchmarkResult entries containing the configuration, compiled
917917
callable, measured performance, status, and compilation time.
918918
"""
919+
config_filter = self.settings.config_filter
920+
if config_filter is not None:
921+
passing_indices = [i for i, c in enumerate(configs) if config_filter(c)]
922+
if len(passing_indices) < len(configs):
923+
passing_configs = [configs[i] for i in passing_indices]
924+
inner_results = self._benchmark(passing_configs, desc=desc)
925+
inner_iter = iter(inner_results)
926+
merged: list[BenchmarkResult] = []
927+
passing_set = set(passing_indices)
928+
for i, config in enumerate(configs):
929+
if i in passing_set:
930+
merged.append(next(inner_iter))
931+
else:
932+
self.log.debug(
933+
f"Config filtered out by config_filter: {config!r}"
934+
)
935+
merged.append(
936+
BenchmarkResult(
937+
config=config,
938+
fn=lambda *a, **kw: None,
939+
perf=inf,
940+
status="filtered",
941+
compile_time=None,
942+
)
943+
)
944+
return merged
945+
919946
fns: list[Callable[..., object]] = []
920947
futures: list[PrecompileFuture] | None = None
921948
for config in configs:
@@ -965,7 +992,9 @@ def _benchmark(
965992
)
966993
else:
967994
compile_time = None
968-
status: Literal["ok", "error", "timeout", "peer_compilation_fail"]
995+
status: Literal[
996+
"ok", "error", "timeout", "peer_compilation_fail", "filtered"
997+
]
969998
if all(
970999
all_gather_object(
9711000
is_working, process_group_name=self.kernel.env.process_group_name
@@ -1162,9 +1191,9 @@ class PopulationMember:
11621191
perfs: list[float]
11631192
flat_values: FlatConfig
11641193
config: Config
1165-
status: Literal["ok", "error", "timeout", "peer_compilation_fail", "unknown"] = (
1166-
"unknown"
1167-
)
1194+
status: Literal[
1195+
"ok", "error", "timeout", "peer_compilation_fail", "filtered", "unknown"
1196+
] = "unknown"
11681197
compile_time: float | None = None
11691198

11701199
@property

helion/runtime/settings.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
if TYPE_CHECKING:
3030
from ..autotuner.base_search import BaseAutotuner
3131
from ..autotuner.pattern_search import InitialPopulationStrategy
32+
from .config import Config
3233
from .kernel import BoundKernel
3334

3435
_T = TypeVar("_T")
@@ -498,6 +499,7 @@ class _Settings:
498499
)
499500
)
500501
autotune_initial_population_strategy: InitialPopulation | None = None
502+
config_filter: Callable[[Config], bool] | None = None
501503

502504

503505
class Settings(_Settings):
@@ -639,6 +641,13 @@ class Settings(_Settings):
639641
"When set, takes precedence over the HELION_AUTOTUNER_INITIAL_POPULATION env var "
640642
"and the effort profile default."
641643
),
644+
"config_filter": (
645+
"Optional callable ``(config: Config) -> bool`` that the autotuner calls on every "
646+
"candidate config before compiling or benchmarking it. Configs for which the "
647+
"callable returns False are skipped entirely (no compilation, no benchmarking). "
648+
"Also filters the explicit ``configs=[...]`` list when one is provided. "
649+
"Pass as @helion.kernel(..., config_filter=my_filter_fn)."
650+
),
642651
}
643652

644653
def __init__(self, **settings: object) -> None:

test/test_autotuner.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,5 +2068,92 @@ def test_autotune_cache_invalid_raises(self):
20682068
bound.settings.autotuner_fn(bound, args)
20692069

20702070

2071+
@onlyBackends(["triton"])
2072+
class TestConfigFilter(TestCase):
2073+
"""Tests for the config_filter setting."""
2074+
2075+
def _make_kernel_and_args(self, **kernel_kwargs):
2076+
@helion.kernel(autotune_log_level=0, **kernel_kwargs)
2077+
def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
2078+
out = torch.empty_like(a)
2079+
for tile in hl.tile(out.size()):
2080+
out[tile] = a[tile] + b[tile]
2081+
return out
2082+
2083+
args = (
2084+
torch.randn([128], device=DEVICE),
2085+
torch.randn([128], device=DEVICE),
2086+
)
2087+
return add, args
2088+
2089+
def test_config_filter_skips_filtered_configs(self) -> None:
2090+
"""Filtered configs produce status='filtered' and perf=inf."""
2091+
cfg1 = helion.Config(block_sizes=[16], num_warps=4)
2092+
cfg2 = helion.Config(block_sizes=[32], num_warps=4)
2093+
cfg3 = helion.Config(block_sizes=[64], num_warps=4)
2094+
2095+
filtered_out: list[helion.Config] = []
2096+
2097+
def my_filter(config: helion.Config) -> bool:
2098+
if config.get("block_sizes") == [32]:
2099+
filtered_out.append(config)
2100+
return False
2101+
return True
2102+
2103+
add, args = self._make_kernel_and_args(
2104+
config_filter=my_filter, autotune_precompile=None
2105+
)
2106+
bound = add.bind(args)
2107+
search = FiniteSearch(bound, args, configs=[cfg1, cfg2, cfg3])
2108+
search._prepare()
2109+
results = search.benchmark_batch([cfg1, cfg2, cfg3])
2110+
2111+
# cfg2 should be filtered
2112+
self.assertEqual(len(filtered_out), 1)
2113+
self.assertEqual(filtered_out[0].get("block_sizes"), [32])
2114+
2115+
statuses = {tuple(r.config.get("block_sizes", [])): r.status for r in results}
2116+
self.assertEqual(statuses[(16,)], "ok")
2117+
self.assertEqual(statuses[(32,)], "filtered")
2118+
self.assertEqual(statuses[(64,)], "ok")
2119+
2120+
perfs = {tuple(r.config.get("block_sizes", [])): r.perf for r in results}
2121+
self.assertEqual(perfs[(32,)], float("inf"))
2122+
2123+
def test_config_filter_affects_autotune_winner(self) -> None:
2124+
"""The autotuner never picks a filtered config as the winner."""
2125+
# cfg_fast would normally win (smallest block = least work per kernel launch
2126+
# in this trivial test), but we filter it out.
2127+
cfg_fast = helion.Config(block_sizes=[16], num_warps=4)
2128+
cfg_slow = helion.Config(block_sizes=[128], num_warps=4)
2129+
2130+
def reject_small_blocks(config: helion.Config) -> bool:
2131+
return (config.get("block_sizes") or [0])[0] >= 64
2132+
2133+
add, args = self._make_kernel_and_args(config_filter=reject_small_blocks)
2134+
bound = add.bind(args)
2135+
search = FiniteSearch(bound, args, configs=[cfg_fast, cfg_slow])
2136+
winner = search.autotune()
2137+
# cfg_fast is filtered out, so cfg_slow must win
2138+
self.assertEqual(winner.get("block_sizes"), [128])
2139+
2140+
def test_config_filter_none_is_noop(self) -> None:
2141+
"""When config_filter=None (default), all configs are benchmarked normally."""
2142+
cfg1 = helion.Config(block_sizes=[16], num_warps=4)
2143+
cfg2 = helion.Config(block_sizes=[32], num_warps=4)
2144+
2145+
add, args = self._make_kernel_and_args(
2146+
autotune_precompile=None
2147+
) # no config_filter
2148+
bound = add.bind(args)
2149+
search = FiniteSearch(bound, args, configs=[cfg1, cfg2])
2150+
search._prepare()
2151+
results = search.benchmark_batch([cfg1, cfg2])
2152+
2153+
for result in results:
2154+
self.assertNotEqual(result.status, "filtered")
2155+
self.assertFalse(math.isinf(result.perf))
2156+
2157+
20712158
if __name__ == "__main__":
20722159
unittest.main()

0 commit comments

Comments
 (0)