@@ -2225,5 +2225,92 @@ def test_autotune_cache_invalid_raises(self):
22252225 bound .settings .autotuner_fn (bound , args )
22262226
22272227
@onlyBackends(["triton"])
class TestConfigFilter(TestCase):
    """Exercise the ``config_filter`` kernel setting through ``FiniteSearch``."""

    def _make_kernel_and_args(self, **kernel_kwargs):
        """Return an elementwise-add kernel together with two inputs for it.

        ``kernel_kwargs`` are forwarded unchanged to ``helion.kernel`` next to
        ``autotune_log_level=0``.
        """

        @helion.kernel(autotune_log_level=0, **kernel_kwargs)
        def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            out = torch.empty_like(a)
            for tile in hl.tile(out.size()):
                out[tile] = a[tile] + b[tile]
            return out

        lhs = torch.randn([128], device=DEVICE)
        rhs = torch.randn([128], device=DEVICE)
        return add, (lhs, rhs)

    def test_config_filter_skips_filtered_configs(self) -> None:
        """A rejected config is reported as status 'filtered' with perf=inf."""
        configs = [
            helion.Config(block_sizes=[size], num_warps=4) for size in (16, 32, 64)
        ]
        rejected: list[helion.Config] = []

        def drop_block_size_32(config: helion.Config) -> bool:
            # Record every rejection so the test can check what was filtered.
            is_target = config.get("block_sizes") == [32]
            if is_target:
                rejected.append(config)
            return not is_target

        kernel, inputs = self._make_kernel_and_args(
            config_filter=drop_block_size_32, autotune_precompile=None
        )
        search = FiniteSearch(kernel.bind(inputs), inputs, configs=list(configs))
        search._prepare()
        outcomes = search.benchmark_batch(list(configs))

        # Exactly one config, the block_sizes=[32] one, must have been rejected.
        self.assertEqual(len(rejected), 1)
        self.assertEqual(rejected[0].get("block_sizes"), [32])

        # Index the results by block size so each config can be checked by key.
        by_block = {
            tuple(outcome.config.get("block_sizes", [])): outcome
            for outcome in outcomes
        }
        for key, expected in (((16,), "ok"), ((32,), "filtered"), ((64,), "ok")):
            self.assertEqual(by_block[key].status, expected)

        self.assertEqual(by_block[(32,)].perf, float("inf"))

    def test_config_filter_affects_autotune_winner(self) -> None:
        """Autotuning must not return a config that the filter rejected."""
        # The 16-wide config is excluded by the filter below, which leaves the
        # 128-wide config as the only candidate the search may select.
        small = helion.Config(block_sizes=[16], num_warps=4)
        large = helion.Config(block_sizes=[128], num_warps=4)

        def keep_blocks_of_at_least_64(config: helion.Config) -> bool:
            leading = (config.get("block_sizes") or [0])[0]
            return leading >= 64

        kernel, inputs = self._make_kernel_and_args(
            config_filter=keep_blocks_of_at_least_64
        )
        search = FiniteSearch(kernel.bind(inputs), inputs, configs=[small, large])
        best = search.autotune()
        self.assertEqual(best.get("block_sizes"), [128])

    def test_config_filter_none_is_noop(self) -> None:
        """Without a config_filter, no result is marked filtered or given inf perf."""
        configs = [
            helion.Config(block_sizes=[size], num_warps=4) for size in (16, 32)
        ]

        # config_filter is deliberately not passed here.
        kernel, inputs = self._make_kernel_and_args(autotune_precompile=None)
        search = FiniteSearch(kernel.bind(inputs), inputs, configs=list(configs))
        search._prepare()

        for outcome in search.benchmark_batch(list(configs)):
            self.assertNotEqual(outcome.status, "filtered")
            self.assertFalse(math.isinf(outcome.perf))
2314+
# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments