Skip to content

Commit 98ab5e0

Browse files
committed
custom config filter
stack-info: PR: #1847, branch: shunting314/stack/25
1 parent 5318ddb commit 98ab5e0

File tree

3 files changed

+129
-5
lines changed

3 files changed

+129
-5
lines changed

helion/autotuner/base_search.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ class BenchmarkResult(NamedTuple):
206206
config: Config
207207
fn: Callable[..., object]
208208
perf: float
209-
status: Literal["ok", "error", "timeout", "peer_compilation_fail"]
209+
status: Literal["ok", "error", "timeout", "peer_compilation_fail", "filtered"]
210210
compile_time: float | None
211211

212212

@@ -913,6 +913,33 @@ def _benchmark(
913913
A list of BenchmarkResult entries containing the configuration, compiled
914914
callable, measured performance, status, and compilation time.
915915
"""
916+
config_filter = self.settings.config_filter
917+
if config_filter is not None:
918+
passing_indices = [i for i, c in enumerate(configs) if config_filter(c)]
919+
if len(passing_indices) < len(configs):
920+
passing_configs = [configs[i] for i in passing_indices]
921+
inner_results = self._benchmark(passing_configs, desc=desc)
922+
inner_iter = iter(inner_results)
923+
merged: list[BenchmarkResult] = []
924+
passing_set = set(passing_indices)
925+
for i, config in enumerate(configs):
926+
if i in passing_set:
927+
merged.append(next(inner_iter))
928+
else:
929+
self.log.debug(
930+
f"Config filtered out by config_filter: {config!r}"
931+
)
932+
merged.append(
933+
BenchmarkResult(
934+
config=config,
935+
fn=lambda *a, **kw: None,
936+
perf=inf,
937+
status="filtered",
938+
compile_time=None,
939+
)
940+
)
941+
return merged
942+
916943
fns: list[Callable[..., object]] = []
917944
valid_configs: list[Config] = []
918945
futures: list[PrecompileFuture] | None = None
@@ -976,7 +1003,9 @@ def _benchmark(
9761003
)
9771004
else:
9781005
compile_time = None
979-
status: Literal["ok", "error", "timeout", "peer_compilation_fail"]
1006+
status: Literal[
1007+
"ok", "error", "timeout", "peer_compilation_fail", "filtered"
1008+
]
9801009
if all(
9811010
all_gather_object(
9821011
is_working, process_group_name=self.kernel.env.process_group_name
@@ -1174,9 +1203,9 @@ class PopulationMember:
11741203
perfs: list[float]
11751204
flat_values: FlatConfig
11761205
config: Config
1177-
status: Literal["ok", "error", "timeout", "peer_compilation_fail", "unknown"] = (
1178-
"unknown"
1179-
)
1206+
status: Literal[
1207+
"ok", "error", "timeout", "peer_compilation_fail", "filtered", "unknown"
1208+
] = "unknown"
11801209
compile_time: float | None = None
11811210

11821211
@property

helion/runtime/settings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
if TYPE_CHECKING:
3030
from ..autotuner.base_search import BaseAutotuner
3131
from ..autotuner.pattern_search import InitialPopulationStrategy
32+
from .config import Config
3233
from .kernel import BoundKernel
3334

3435
_T = TypeVar("_T")
@@ -513,6 +514,7 @@ class _Settings:
513514
_env_get_bool, "HELION_AUTOTUNE_WITH_TORCH_COMPILE_FUSION", False
514515
)
515516
)
517+
config_filter: Callable[[Config], bool] | None = None
516518

517519

518520
class Settings(_Settings):
@@ -658,6 +660,12 @@ class Settings(_Settings):
658660
"If True, allow torch.compile to fuse this Helion kernel with surrounding Inductor ops "
659661
"(prologue/epilogue) when used inside torch.compile. Default False. "
660662
"Set HELION_TORCH_COMPILE_FUSION=1 to enable globally."
663+
),
"config_filter": (
664+
"Optional callable ``(config: Config) -> bool`` that the autotuner calls on every "
665+
"candidate config before compiling or benchmarking it. Configs for which the "
666+
"callable returns False are skipped entirely (no compilation, no benchmarking). "
667+
"Also filters the explicit ``configs=[...]`` list when one is provided. "
668+
"Pass as @helion.kernel(..., config_filter=my_filter_fn)."
661669
),
662670
"autotune_with_torch_compile_fusion": (
663671
"If True, autotuning benchmarks the fused kernel (with epilogue/prologue) "

test/test_autotuner.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,5 +2225,92 @@ def test_autotune_cache_invalid_raises(self):
22252225
bound.settings.autotuner_fn(bound, args)
22262226

22272227

2228+
@onlyBackends(["triton"])
2229+
class TestConfigFilter(TestCase):
2230+
"""Tests for the config_filter setting."""
2231+
2232+
def _make_kernel_and_args(self, **kernel_kwargs):
2233+
@helion.kernel(autotune_log_level=0, **kernel_kwargs)
2234+
def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
2235+
out = torch.empty_like(a)
2236+
for tile in hl.tile(out.size()):
2237+
out[tile] = a[tile] + b[tile]
2238+
return out
2239+
2240+
args = (
2241+
torch.randn([128], device=DEVICE),
2242+
torch.randn([128], device=DEVICE),
2243+
)
2244+
return add, args
2245+
2246+
def test_config_filter_skips_filtered_configs(self) -> None:
2247+
"""Filtered configs produce status='filtered' and perf=inf."""
2248+
cfg1 = helion.Config(block_sizes=[16], num_warps=4)
2249+
cfg2 = helion.Config(block_sizes=[32], num_warps=4)
2250+
cfg3 = helion.Config(block_sizes=[64], num_warps=4)
2251+
2252+
filtered_out: list[helion.Config] = []
2253+
2254+
def my_filter(config: helion.Config) -> bool:
2255+
if config.get("block_sizes") == [32]:
2256+
filtered_out.append(config)
2257+
return False
2258+
return True
2259+
2260+
add, args = self._make_kernel_and_args(
2261+
config_filter=my_filter, autotune_precompile=None
2262+
)
2263+
bound = add.bind(args)
2264+
search = FiniteSearch(bound, args, configs=[cfg1, cfg2, cfg3])
2265+
search._prepare()
2266+
results = search.benchmark_batch([cfg1, cfg2, cfg3])
2267+
2268+
# cfg2 should be filtered
2269+
self.assertEqual(len(filtered_out), 1)
2270+
self.assertEqual(filtered_out[0].get("block_sizes"), [32])
2271+
2272+
statuses = {tuple(r.config.get("block_sizes", [])): r.status for r in results}
2273+
self.assertEqual(statuses[(16,)], "ok")
2274+
self.assertEqual(statuses[(32,)], "filtered")
2275+
self.assertEqual(statuses[(64,)], "ok")
2276+
2277+
perfs = {tuple(r.config.get("block_sizes", [])): r.perf for r in results}
2278+
self.assertEqual(perfs[(32,)], float("inf"))
2279+
2280+
def test_config_filter_affects_autotune_winner(self) -> None:
2281+
"""The autotuner never picks a filtered config as the winner."""
2282+
# cfg_fast would normally win (smallest block = least work per kernel launch
2283+
# in this trivial test), but we filter it out.
2284+
cfg_fast = helion.Config(block_sizes=[16], num_warps=4)
2285+
cfg_slow = helion.Config(block_sizes=[128], num_warps=4)
2286+
2287+
def reject_small_blocks(config: helion.Config) -> bool:
2288+
return (config.get("block_sizes") or [0])[0] >= 64
2289+
2290+
add, args = self._make_kernel_and_args(config_filter=reject_small_blocks)
2291+
bound = add.bind(args)
2292+
search = FiniteSearch(bound, args, configs=[cfg_fast, cfg_slow])
2293+
winner = search.autotune()
2294+
# cfg_fast is filtered out, so cfg_slow must win
2295+
self.assertEqual(winner.get("block_sizes"), [128])
2296+
2297+
def test_config_filter_none_is_noop(self) -> None:
2298+
"""When config_filter=None (default), all configs are benchmarked normally."""
2299+
cfg1 = helion.Config(block_sizes=[16], num_warps=4)
2300+
cfg2 = helion.Config(block_sizes=[32], num_warps=4)
2301+
2302+
add, args = self._make_kernel_and_args(
2303+
autotune_precompile=None
2304+
) # no config_filter
2305+
bound = add.bind(args)
2306+
search = FiniteSearch(bound, args, configs=[cfg1, cfg2])
2307+
search._prepare()
2308+
results = search.benchmark_batch([cfg1, cfg2])
2309+
2310+
for result in results:
2311+
self.assertNotEqual(result.status, "filtered")
2312+
self.assertFalse(math.isinf(result.perf))
2313+
2314+
22282315
if __name__ == "__main__":
22292316
unittest.main()

0 commit comments

Comments
 (0)