@@ -2068,5 +2068,92 @@ def test_autotune_cache_invalid_raises(self):
20682068 bound .settings .autotuner_fn (bound , args )
20692069
20702070
2071+ @onlyBackends (["triton" ])
2072+ class TestConfigFilter (TestCase ):
2073+ """Tests for the config_filter setting."""
2074+
2075+ def _make_kernel_and_args (self , ** kernel_kwargs ):
2076+ @helion .kernel (autotune_log_level = 0 , ** kernel_kwargs )
2077+ def add (a : torch .Tensor , b : torch .Tensor ) -> torch .Tensor :
2078+ out = torch .empty_like (a )
2079+ for tile in hl .tile (out .size ()):
2080+ out [tile ] = a [tile ] + b [tile ]
2081+ return out
2082+
2083+ args = (
2084+ torch .randn ([128 ], device = DEVICE ),
2085+ torch .randn ([128 ], device = DEVICE ),
2086+ )
2087+ return add , args
2088+
2089+ def test_config_filter_skips_filtered_configs (self ) -> None :
2090+ """Filtered configs produce status='filtered' and perf=inf."""
2091+ cfg1 = helion .Config (block_sizes = [16 ], num_warps = 4 )
2092+ cfg2 = helion .Config (block_sizes = [32 ], num_warps = 4 )
2093+ cfg3 = helion .Config (block_sizes = [64 ], num_warps = 4 )
2094+
2095+ filtered_out : list [helion .Config ] = []
2096+
2097+ def my_filter (config : helion .Config ) -> bool :
2098+ if config .get ("block_sizes" ) == [32 ]:
2099+ filtered_out .append (config )
2100+ return False
2101+ return True
2102+
2103+ add , args = self ._make_kernel_and_args (
2104+ config_filter = my_filter , autotune_precompile = None
2105+ )
2106+ bound = add .bind (args )
2107+ search = FiniteSearch (bound , args , configs = [cfg1 , cfg2 , cfg3 ])
2108+ search ._prepare ()
2109+ results = search .benchmark_batch ([cfg1 , cfg2 , cfg3 ])
2110+
2111+ # cfg2 should be filtered
2112+ self .assertEqual (len (filtered_out ), 1 )
2113+ self .assertEqual (filtered_out [0 ].get ("block_sizes" ), [32 ])
2114+
2115+ statuses = {tuple (r .config .get ("block_sizes" , [])): r .status for r in results }
2116+ self .assertEqual (statuses [(16 ,)], "ok" )
2117+ self .assertEqual (statuses [(32 ,)], "filtered" )
2118+ self .assertEqual (statuses [(64 ,)], "ok" )
2119+
2120+ perfs = {tuple (r .config .get ("block_sizes" , [])): r .perf for r in results }
2121+ self .assertEqual (perfs [(32 ,)], float ("inf" ))
2122+
2123+ def test_config_filter_affects_autotune_winner (self ) -> None :
2124+ """The autotuner never picks a filtered config as the winner."""
2125+ # cfg_fast would normally win (smallest block = least work per kernel launch
2126+ # in this trivial test), but we filter it out.
2127+ cfg_fast = helion .Config (block_sizes = [16 ], num_warps = 4 )
2128+ cfg_slow = helion .Config (block_sizes = [128 ], num_warps = 4 )
2129+
2130+ def reject_small_blocks (config : helion .Config ) -> bool :
2131+ return (config .get ("block_sizes" ) or [0 ])[0 ] >= 64
2132+
2133+ add , args = self ._make_kernel_and_args (config_filter = reject_small_blocks )
2134+ bound = add .bind (args )
2135+ search = FiniteSearch (bound , args , configs = [cfg_fast , cfg_slow ])
2136+ winner = search .autotune ()
2137+ # cfg_fast is filtered out, so cfg_slow must win
2138+ self .assertEqual (winner .get ("block_sizes" ), [128 ])
2139+
2140+ def test_config_filter_none_is_noop (self ) -> None :
2141+ """When config_filter=None (default), all configs are benchmarked normally."""
2142+ cfg1 = helion .Config (block_sizes = [16 ], num_warps = 4 )
2143+ cfg2 = helion .Config (block_sizes = [32 ], num_warps = 4 )
2144+
2145+ add , args = self ._make_kernel_and_args (
2146+ autotune_precompile = None
2147+ ) # no config_filter
2148+ bound = add .bind (args )
2149+ search = FiniteSearch (bound , args , configs = [cfg1 , cfg2 ])
2150+ search ._prepare ()
2151+ results = search .benchmark_batch ([cfg1 , cfg2 ])
2152+
2153+ for result in results :
2154+ self .assertNotEqual (result .status , "filtered" )
2155+ self .assertFalse (math .isinf (result .perf ))
2156+
2157+
20712158if __name__ == "__main__" :
20722159 unittest .main ()
0 commit comments