11from __future__ import annotations
22
3+ import contextlib
34import functools
45import math
56import operator
2930from helion ._testing import skipIfTileIR
3031import helion .language as hl
3132
def requires_fusion_support (test_fn ):
    """Decorator: when fusion is unsupported, assert the upgrade error instead of running.

    If ``supports_torch_compile_fusion()`` reports support, the wrapped
    test executes normally under a no-op context.  Otherwise the test body
    is run inside ``assertRaisesRegex`` so that the expected upgrade error
    is asserted rather than letting the test fail or be skipped.
    """

    @functools.wraps(test_fn)
    def wrapper(self, *args, **kwargs):
        if supports_torch_compile_fusion():
            # Fusion available: run the test body unchanged.
            guard = contextlib.nullcontext()
        else:
            # Fusion unavailable: the test body itself is expected to raise
            # the upgrade error, which we assert here.
            guard = self.assertRaisesRegex(
                RuntimeError,
                "torch_compile_fusion=True requires PyTorch nightly build",
            )
        with guard:
            test_fn(self, *args, **kwargs)

    return wrapper
51+
52+
3253# -----------------------------------------------------------------------------
3354# Basic Operations (no mutation, return new tensor)
3455# -----------------------------------------------------------------------------
@@ -494,11 +515,11 @@ def _run_compile_test(
494515 expected_num_kernels_ref : int | None = None ,
495516 ):
496517 """Run torch.compile test comparing eager vs compiled execution."""
497- if allow_torch_compile_fusion :
498- if not supports_torch_compile_fusion ():
499- self . skipTest (
500- "torch.compile fusion requires ExternalTritonTemplateKernel support"
501- )
518+ if allow_torch_compile_fusion and not supports_torch_compile_fusion () :
519+ expected_error = (
520+ RuntimeError ,
521+ "torch_compile_fusion=True requires PyTorch nightly build" ,
522+ )
502523
503524 # Reset specific kernels and configure fusion setting
504525 for kernel in kernels :
@@ -2764,11 +2785,6 @@ def test_dynamic_shapes_rejects_static_shapes(self, allow_torch_compile_fusion):
27642785 sizes, producing wrong results at runtime. We now raise a clear
27652786 error instead.
27662787 """
2767- if not supports_torch_compile_fusion ():
2768- self .skipTest (
2769- "static_shapes check requires HOP lowering path "
2770- "(ExternalTritonTemplateKernel support)"
2771- )
27722788
27732789 def f (x : torch .Tensor , y : torch .Tensor , * , _kernels = (k_add ,)) -> torch .Tensor :
27742790 return _kernels [0 ](x , y )
@@ -2781,7 +2797,9 @@ def f(x: torch.Tensor, y: torch.Tensor, *, _kernels=(k_add,)) -> torch.Tensor:
27812797 kernels = [k_add ],
27822798 dynamic = True ,
27832799 allow_torch_compile_fusion = allow_torch_compile_fusion ,
2784- expected_error = (RuntimeError , "static_shapes=True.*dynamic=True" ),
2800+ expected_error = (RuntimeError , "static_shapes=True.*dynamic=True" )
2801+ if supports_torch_compile_fusion ()
2802+ else None ,
27852803 )
27862804
27872805 @parametrize ("allow_torch_compile_fusion" , (True , False ))
@@ -4137,16 +4155,13 @@ def compare(actual, expected):
41374155 expected_num_kernels_ref = 1 ,
41384156 )
41394157
4158+ @requires_fusion_support
41404159 @parametrize ("allow_torch_compile_fusion" , (True , False ))
41414160 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
41424161 def test_symint_return_from_tensor_shape (self , allow_torch_compile_fusion ):
41434162 """Test: kernel returning SymInt (tensor shape) with dynamic shapes."""
41444163 if not allow_torch_compile_fusion :
41454164 self .skipTest ("Only testing with torch.compile fusion enabled" )
4146- if not supports_torch_compile_fusion ():
4147- self .skipTest (
4148- "torch.compile fusion requires ExternalTritonTemplateKernel support"
4149- )
41504165
41514166 @helion .kernel (
41524167 autotune_effort = "none" , static_shapes = False , torch_compile_fusion = True
@@ -4568,15 +4583,12 @@ def patched_load(code, *args, **kwargs):
45684583
45694584 # --- Autotune-with-fusion tests ---
45704585
4586+ @requires_fusion_support
45714587 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
45724588 @parametrize ("autotune_with_fusion" , (True , False ))
45734589 def test_autotune_fusion_aware_vs_default (self , autotune_with_fusion ):
45744590 """When fusion-aware autotuning is on, each config is benchmarked as fused code;
45754591 when off, the pre-existing BoundKernel config is reused without recompilation."""
4576- if not supports_torch_compile_fusion ():
4577- self .skipTest (
4578- "torch.compile fusion requires ExternalTritonTemplateKernel support"
4579- )
45804592
45814593 kernel = self ._make_autotune_kernel (
45824594 autotune_with_torch_compile_fusion = autotune_with_fusion
@@ -4632,13 +4644,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
46324644 result_direct = kernel (x .clone (), y .clone ())
46334645 torch .testing .assert_close (result_direct , x + y )
46344646
4647+ @requires_fusion_support
46354648 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
46364649 def test_autotune_fusion_recompile (self ):
46374650 """Recompile with shared BoundKernel still produces fused code."""
4638- if not supports_torch_compile_fusion ():
4639- self .skipTest (
4640- "torch.compile fusion requires ExternalTritonTemplateKernel support"
4641- )
46424651
46434652 kernel = self ._make_autotune_kernel ()
46444653
@@ -4676,15 +4685,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
46764685 self .assertIn (pp , code , "Must have prologue fusion" )
46774686 self .assertEqual (code .count ("@triton.jit" ), 1 , "Single fused kernel" )
46784687
4688+ @requires_fusion_support
46794689 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
46804690 def test_autotune_different_epilogues (self ):
46814691 """Different epilogues (relu vs sigmoid) trigger separate autotuning."""
4682- if not supports_torch_compile_fusion ():
4683- self .skipTest (
4684- "torch.compile fusion requires ExternalTritonTemplateKernel support"
4685- )
4686- from helion ._compiler ._inductor .template_buffer import HelionTemplateBuffer
4687-
46884692 kernel = self ._make_autotune_kernel ()
46894693 captured_codes , patch_ctx = self ._make_code_capture ()
46904694
@@ -4716,6 +4720,8 @@ def g(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
47164720 self .assertGreater (len (sigmoid_only ), 0 , "Must have sigmoid-only kernel(s)" )
47174721
47184722 # Different epilogues must produce separate fusion-context cache entries
4723+ from helion ._compiler ._inductor .template_buffer import HelionTemplateBuffer
4724+
47194725 bk = next (iter (kernel ._bound_kernels .values ()))
47204726 bk_cache = HelionTemplateBuffer ._fusion_config_cache .get (bk )
47214727 self .assertIsNotNone (
@@ -4737,13 +4743,10 @@ def g(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
47374743 )
47384744 self .assertEqual (len (captured_codes ), 0 , "Re-run must reuse cached kernels" )
47394745
4746+ @requires_fusion_support
47404747 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
47414748 def test_autotune_different_shapes (self ):
47424749 """Different input shapes trigger re-autotuning with fused kernels."""
4743- if not supports_torch_compile_fusion ():
4744- self .skipTest (
4745- "torch.compile fusion requires ExternalTritonTemplateKernel support"
4746- )
47474750
47484751 kernel = self ._make_autotune_kernel ()
47494752 captured_codes , patch_ctx = self ._make_code_capture ()
@@ -4792,13 +4795,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
47924795 second_fused_count , 0 , "New shape must trigger new fused compilations"
47934796 )
47944797
4798+ @requires_fusion_support
47954799 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
47964800 def test_autotune_same_epilogue_cache (self ):
47974801 """Same kernel + same epilogue called twice → second hits fusion cache."""
4798- if not supports_torch_compile_fusion ():
4799- self .skipTest (
4800- "torch.compile fusion requires ExternalTritonTemplateKernel support"
4801- )
48024802
48034803 kernel = self ._make_autotune_kernel ()
48044804 captured_codes , patch_ctx = self ._make_code_capture ()
@@ -4864,6 +4864,7 @@ def f_double(
48644864 f"(got { double_call_fused_count } new fused compilations)" ,
48654865 )
48664866
4867+ @requires_fusion_support
48674868 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
48684869 def test_standalone_call_after_fusion_triggers_autotuning (self ):
48694870 """Standalone call after torch.compile with fusion must trigger its own autotuning.
@@ -4873,12 +4874,6 @@ def test_standalone_call_after_fusion_triggers_autotuning(self):
48734874 workload. A subsequent direct call must trigger autotuning for
48744875 the unfused context rather than silently reusing the fused config.
48754876 """
4876- if not supports_torch_compile_fusion ():
4877- self .skipTest (
4878- "torch.compile fusion requires ExternalTritonTemplateKernel support"
4879- )
4880- from helion ._compiler ._inductor .template_buffer import HelionTemplateBuffer
4881- from helion .runtime .kernel import BoundKernel
48824877
48834878 kernel = self ._make_autotune_kernel ()
48844879
@@ -4900,6 +4895,9 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
49004895 torch .testing .assert_close (result_fused , expected_fused , rtol = 1e-4 , atol = 1e-4 )
49014896
49024897 # Spy on compile_config to verify the standalone call triggers autotuning.
4898+ from helion ._compiler ._inductor .template_buffer import HelionTemplateBuffer
4899+ from helion .runtime .kernel import BoundKernel
4900+
49034901 compile_config_calls : list [bool ] = []
49044902 original_compile_config = BoundKernel .compile_config
49054903
@@ -4932,21 +4930,21 @@ def tracking_compile_config(self_bk, *args, **kwargs):
49324930 "Fusion config cache should have entries from the torch.compile call" ,
49334931 )
49344932
4933+ @requires_fusion_support
49354934 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
49364935 def test_autotune_fused_vs_unfused_config_stored_separately (self ):
49374936 """Unfused config (bk._config) and fused config (_fusion_config_cache) are independent."""
4938- if not supports_torch_compile_fusion ():
4939- self .skipTest (
4940- "torch.compile fusion requires ExternalTritonTemplateKernel support"
4941- )
4942- from helion ._compiler ._inductor .template_buffer import HelionTemplateBuffer
49434937 from helion .runtime .config import Config
49444938
49454939 kernel = self ._make_autotune_kernel ()
49464940
49474941 kernel .reset ()
49484942 torch ._dynamo .reset ()
4949- HelionTemplateBuffer ._fusion_config_cache .clear ()
4943+
4944+ if supports_torch_compile_fusion ():
4945+ from helion ._compiler ._inductor .template_buffer import HelionTemplateBuffer
4946+
4947+ HelionTemplateBuffer ._fusion_config_cache .clear ()
49504948
49514949 x = torch .randn (128 , device = DEVICE , dtype = torch .float32 )
49524950 y = torch .randn (128 , device = DEVICE , dtype = torch .float32 )
@@ -4989,13 +4987,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
49894987 f"Fusion cache entry { fusion_key !r} must be a Config" ,
49904988 )
49914989
4990+ @requires_fusion_support
49924991 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
49934992 def test_autotune_epilogue_only_fusion (self ):
49944993 """Fusion-aware autotuning works with epilogue only (no prologue)."""
4995- if not supports_torch_compile_fusion ():
4996- self .skipTest (
4997- "torch.compile fusion requires ExternalTritonTemplateKernel support"
4998- )
49994994
50004995 kernel = self ._make_autotune_kernel ()
50014996 captured_codes , patch_ctx = self ._make_code_capture ()
@@ -5023,13 +5018,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
50235018 for code in fused :
50245019 self .assertNotIn (pp , code , "Must NOT have prologue in epilogue-only test" )
50255020
5021+ @requires_fusion_support
50265022 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
50275023 def test_autotune_prologue_only_fusion (self ):
50285024 """Fusion-aware autotuning works with prologue only (no epilogue)."""
5029- if not supports_torch_compile_fusion ():
5030- self .skipTest (
5031- "torch.compile fusion requires ExternalTritonTemplateKernel support"
5032- )
50335025
50345026 kernel = self ._make_autotune_kernel ()
50355027 captured_codes , patch_ctx = self ._make_code_capture ()
@@ -5059,13 +5051,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
50595051 for code in fused :
50605052 self .assertNotIn (ep , code , "Must NOT have epilogue in prologue-only test" )
50615053
5054+ @requires_fusion_support
50625055 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
50635056 def test_autotune_bare_kernel_no_prologue_epilogue (self ):
50645057 """Fusion-aware autotuning does not break when Inductor has no prologue or epilogue to fuse."""
5065- if not supports_torch_compile_fusion ():
5066- self .skipTest (
5067- "torch.compile fusion requires ExternalTritonTemplateKernel support"
5068- )
50695058
50705059 kernel = self ._make_autotune_kernel ()
50715060
@@ -5125,6 +5114,11 @@ def k_add_no_configs(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
51255114
51265115 # torch.compile with fusion — must not crash and must not silently
51275116 # reuse the cached unfused config.
5117+ torch ._dynamo .reset ()
5118+
5119+ def f (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
5120+ return torch .relu (k_add_no_configs (x , y )) + 1.0
5121+
51285122 from helion ._compiler ._inductor .template_buffer import _FusionAutotuneAdapter
51295123
51305124 bench_compile_called : list [bool ] = []
@@ -5134,11 +5128,6 @@ def tracking_bench(adapter_self, config=None, **kwargs):
51345128 bench_compile_called .append (True )
51355129 return original_bench (adapter_self , config , ** kwargs )
51365130
5137- torch ._dynamo .reset ()
5138-
5139- def f (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
5140- return torch .relu (k_add_no_configs (x , y )) + 1.0
5141-
51425131 with patch .object (
51435132 _FusionAutotuneAdapter , "bench_compile_config" , tracking_bench
51445133 ):
@@ -5158,14 +5147,11 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
51585147 "bench_compile_config, not reuse the cached unfused config" ,
51595148 )
51605149
5150+ @requires_fusion_support
51615151 @skipIfTileIR ("torch.compile missing kernel metadata on tileir" )
51625152 @patch .object (k_rms_norm .settings , "torch_compile_fusion" , True )
51635153 def test_inductor_output_code_has_helion_generated_triton_kernel (self ):
51645154 """Verify Helion-specific patterns appear in inductor output code."""
5165- if not supports_torch_compile_fusion ():
5166- self .skipTest (
5167- "torch.compile fusion requires ExternalTritonTemplateKernel support"
5168- )
51695155
51705156 def f (x , weight , out_bias , res_bias ):
51715157 x_processed = torch .relu (x ) + 0.5
0 commit comments