Skip to content

Commit 655be3f

Browse files
authored
[Helion + torch.compile] Replace skipTest guards with error assertions for fusion tests (#1932)
1 parent 85bf24f commit 655be3f

File tree

1 file changed

+56
-70
lines changed

1 file changed

+56
-70
lines changed

test/test_torch_compile.py

Lines changed: 56 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import functools
45
import math
56
import operator
@@ -29,6 +30,26 @@
2930
from helion._testing import skipIfTileIR
3031
import helion.language as hl
3132

33+
34+
def requires_fusion_support(test_fn):
    """Decorator: when fusion is unsupported, assert the upgrade error instead of running."""

    @functools.wraps(test_fn)
    def wrapper(self, *args, **kwargs):
        # When the running PyTorch build supports torch.compile fusion, the
        # test executes normally (nullcontext is a no-op). Otherwise the test
        # body is expected to raise the nightly-build upgrade error, which we
        # assert instead of skipping.
        if supports_torch_compile_fusion():
            ctx = contextlib.nullcontext()
        else:
            ctx = self.assertRaisesRegex(
                RuntimeError,
                "torch_compile_fusion=True requires PyTorch nightly build",
            )
        with ctx:
            test_fn(self, *args, **kwargs)

    return wrapper
51+
52+
3253
# -----------------------------------------------------------------------------
3354
# Basic Operations (no mutation, return new tensor)
3455
# -----------------------------------------------------------------------------
@@ -494,11 +515,11 @@ def _run_compile_test(
494515
expected_num_kernels_ref: int | None = None,
495516
):
496517
"""Run torch.compile test comparing eager vs compiled execution."""
497-
if allow_torch_compile_fusion:
498-
if not supports_torch_compile_fusion():
499-
self.skipTest(
500-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
501-
)
518+
if allow_torch_compile_fusion and not supports_torch_compile_fusion():
519+
expected_error = (
520+
RuntimeError,
521+
"torch_compile_fusion=True requires PyTorch nightly build",
522+
)
502523

503524
# Reset specific kernels and configure fusion setting
504525
for kernel in kernels:
@@ -2764,11 +2785,6 @@ def test_dynamic_shapes_rejects_static_shapes(self, allow_torch_compile_fusion):
27642785
sizes, producing wrong results at runtime. We now raise a clear
27652786
error instead.
27662787
"""
2767-
if not supports_torch_compile_fusion():
2768-
self.skipTest(
2769-
"static_shapes check requires HOP lowering path "
2770-
"(ExternalTritonTemplateKernel support)"
2771-
)
27722788

27732789
def f(x: torch.Tensor, y: torch.Tensor, *, _kernels=(k_add,)) -> torch.Tensor:
27742790
return _kernels[0](x, y)
@@ -2781,7 +2797,9 @@ def f(x: torch.Tensor, y: torch.Tensor, *, _kernels=(k_add,)) -> torch.Tensor:
27812797
kernels=[k_add],
27822798
dynamic=True,
27832799
allow_torch_compile_fusion=allow_torch_compile_fusion,
2784-
expected_error=(RuntimeError, "static_shapes=True.*dynamic=True"),
2800+
expected_error=(RuntimeError, "static_shapes=True.*dynamic=True")
2801+
if supports_torch_compile_fusion()
2802+
else None,
27852803
)
27862804

27872805
@parametrize("allow_torch_compile_fusion", (True, False))
@@ -4137,16 +4155,13 @@ def compare(actual, expected):
41374155
expected_num_kernels_ref=1,
41384156
)
41394157

4158+
@requires_fusion_support
41404159
@parametrize("allow_torch_compile_fusion", (True, False))
41414160
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
41424161
def test_symint_return_from_tensor_shape(self, allow_torch_compile_fusion):
41434162
"""Test: kernel returning SymInt (tensor shape) with dynamic shapes."""
41444163
if not allow_torch_compile_fusion:
41454164
self.skipTest("Only testing with torch.compile fusion enabled")
4146-
if not supports_torch_compile_fusion():
4147-
self.skipTest(
4148-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
4149-
)
41504165

41514166
@helion.kernel(
41524167
autotune_effort="none", static_shapes=False, torch_compile_fusion=True
@@ -4568,15 +4583,12 @@ def patched_load(code, *args, **kwargs):
45684583

45694584
# --- Autotune-with-fusion tests ---
45704585

4586+
@requires_fusion_support
45714587
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
45724588
@parametrize("autotune_with_fusion", (True, False))
45734589
def test_autotune_fusion_aware_vs_default(self, autotune_with_fusion):
45744590
"""When fusion-aware autotuning is on, each config is benchmarked as fused code;
45754591
when off, the pre-existing BoundKernel config is reused without recompilation."""
4576-
if not supports_torch_compile_fusion():
4577-
self.skipTest(
4578-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
4579-
)
45804592

45814593
kernel = self._make_autotune_kernel(
45824594
autotune_with_torch_compile_fusion=autotune_with_fusion
@@ -4632,13 +4644,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
46324644
result_direct = kernel(x.clone(), y.clone())
46334645
torch.testing.assert_close(result_direct, x + y)
46344646

4647+
@requires_fusion_support
46354648
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
46364649
def test_autotune_fusion_recompile(self):
46374650
"""Recompile with shared BoundKernel still produces fused code."""
4638-
if not supports_torch_compile_fusion():
4639-
self.skipTest(
4640-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
4641-
)
46424651

46434652
kernel = self._make_autotune_kernel()
46444653

@@ -4676,15 +4685,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
46764685
self.assertIn(pp, code, "Must have prologue fusion")
46774686
self.assertEqual(code.count("@triton.jit"), 1, "Single fused kernel")
46784687

4688+
@requires_fusion_support
46794689
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
46804690
def test_autotune_different_epilogues(self):
46814691
"""Different epilogues (relu vs sigmoid) trigger separate autotuning."""
4682-
if not supports_torch_compile_fusion():
4683-
self.skipTest(
4684-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
4685-
)
4686-
from helion._compiler._inductor.template_buffer import HelionTemplateBuffer
4687-
46884692
kernel = self._make_autotune_kernel()
46894693
captured_codes, patch_ctx = self._make_code_capture()
46904694

@@ -4716,6 +4720,8 @@ def g(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
47164720
self.assertGreater(len(sigmoid_only), 0, "Must have sigmoid-only kernel(s)")
47174721

47184722
# Different epilogues must produce separate fusion-context cache entries
4723+
from helion._compiler._inductor.template_buffer import HelionTemplateBuffer
4724+
47194725
bk = next(iter(kernel._bound_kernels.values()))
47204726
bk_cache = HelionTemplateBuffer._fusion_config_cache.get(bk)
47214727
self.assertIsNotNone(
@@ -4737,13 +4743,10 @@ def g(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
47374743
)
47384744
self.assertEqual(len(captured_codes), 0, "Re-run must reuse cached kernels")
47394745

4746+
@requires_fusion_support
47404747
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
47414748
def test_autotune_different_shapes(self):
47424749
"""Different input shapes trigger re-autotuning with fused kernels."""
4743-
if not supports_torch_compile_fusion():
4744-
self.skipTest(
4745-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
4746-
)
47474750

47484751
kernel = self._make_autotune_kernel()
47494752
captured_codes, patch_ctx = self._make_code_capture()
@@ -4792,13 +4795,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
47924795
second_fused_count, 0, "New shape must trigger new fused compilations"
47934796
)
47944797

4798+
@requires_fusion_support
47954799
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
47964800
def test_autotune_same_epilogue_cache(self):
47974801
"""Same kernel + same epilogue called twice → second hits fusion cache."""
4798-
if not supports_torch_compile_fusion():
4799-
self.skipTest(
4800-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
4801-
)
48024802

48034803
kernel = self._make_autotune_kernel()
48044804
captured_codes, patch_ctx = self._make_code_capture()
@@ -4864,6 +4864,7 @@ def f_double(
48644864
f"(got {double_call_fused_count} new fused compilations)",
48654865
)
48664866

4867+
@requires_fusion_support
48674868
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
48684869
def test_standalone_call_after_fusion_triggers_autotuning(self):
48694870
"""Standalone call after torch.compile with fusion must trigger its own autotuning.
@@ -4873,12 +4874,6 @@ def test_standalone_call_after_fusion_triggers_autotuning(self):
48734874
workload. A subsequent direct call must trigger autotuning for
48744875
the unfused context rather than silently reusing the fused config.
48754876
"""
4876-
if not supports_torch_compile_fusion():
4877-
self.skipTest(
4878-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
4879-
)
4880-
from helion._compiler._inductor.template_buffer import HelionTemplateBuffer
4881-
from helion.runtime.kernel import BoundKernel
48824877

48834878
kernel = self._make_autotune_kernel()
48844879

@@ -4900,6 +4895,9 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
49004895
torch.testing.assert_close(result_fused, expected_fused, rtol=1e-4, atol=1e-4)
49014896

49024897
# Spy on compile_config to verify the standalone call triggers autotuning.
4898+
from helion._compiler._inductor.template_buffer import HelionTemplateBuffer
4899+
from helion.runtime.kernel import BoundKernel
4900+
49034901
compile_config_calls: list[bool] = []
49044902
original_compile_config = BoundKernel.compile_config
49054903

@@ -4932,21 +4930,21 @@ def tracking_compile_config(self_bk, *args, **kwargs):
49324930
"Fusion config cache should have entries from the torch.compile call",
49334931
)
49344932

4933+
@requires_fusion_support
49354934
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
49364935
def test_autotune_fused_vs_unfused_config_stored_separately(self):
49374936
"""Unfused config (bk._config) and fused config (_fusion_config_cache) are independent."""
4938-
if not supports_torch_compile_fusion():
4939-
self.skipTest(
4940-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
4941-
)
4942-
from helion._compiler._inductor.template_buffer import HelionTemplateBuffer
49434937
from helion.runtime.config import Config
49444938

49454939
kernel = self._make_autotune_kernel()
49464940

49474941
kernel.reset()
49484942
torch._dynamo.reset()
4949-
HelionTemplateBuffer._fusion_config_cache.clear()
4943+
4944+
if supports_torch_compile_fusion():
4945+
from helion._compiler._inductor.template_buffer import HelionTemplateBuffer
4946+
4947+
HelionTemplateBuffer._fusion_config_cache.clear()
49504948

49514949
x = torch.randn(128, device=DEVICE, dtype=torch.float32)
49524950
y = torch.randn(128, device=DEVICE, dtype=torch.float32)
@@ -4989,13 +4987,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
49894987
f"Fusion cache entry {fusion_key!r} must be a Config",
49904988
)
49914989

4990+
@requires_fusion_support
49924991
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
49934992
def test_autotune_epilogue_only_fusion(self):
49944993
"""Fusion-aware autotuning works with epilogue only (no prologue)."""
4995-
if not supports_torch_compile_fusion():
4996-
self.skipTest(
4997-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
4998-
)
49994994

50004995
kernel = self._make_autotune_kernel()
50014996
captured_codes, patch_ctx = self._make_code_capture()
@@ -5023,13 +5018,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
50235018
for code in fused:
50245019
self.assertNotIn(pp, code, "Must NOT have prologue in epilogue-only test")
50255020

5021+
@requires_fusion_support
50265022
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
50275023
def test_autotune_prologue_only_fusion(self):
50285024
"""Fusion-aware autotuning works with prologue only (no epilogue)."""
5029-
if not supports_torch_compile_fusion():
5030-
self.skipTest(
5031-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
5032-
)
50335025

50345026
kernel = self._make_autotune_kernel()
50355027
captured_codes, patch_ctx = self._make_code_capture()
@@ -5059,13 +5051,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
50595051
for code in fused:
50605052
self.assertNotIn(ep, code, "Must NOT have epilogue in prologue-only test")
50615053

5054+
@requires_fusion_support
50625055
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
50635056
def test_autotune_bare_kernel_no_prologue_epilogue(self):
50645057
"""Fusion-aware autotuning does not break when Inductor has no prologue or epilogue to fuse."""
5065-
if not supports_torch_compile_fusion():
5066-
self.skipTest(
5067-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
5068-
)
50695058

50705059
kernel = self._make_autotune_kernel()
50715060

@@ -5125,6 +5114,11 @@ def k_add_no_configs(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
51255114

51265115
# torch.compile with fusion — must not crash and must not silently
51275116
# reuse the cached unfused config.
5117+
torch._dynamo.reset()
5118+
5119+
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
5120+
return torch.relu(k_add_no_configs(x, y)) + 1.0
5121+
51285122
from helion._compiler._inductor.template_buffer import _FusionAutotuneAdapter
51295123

51305124
bench_compile_called: list[bool] = []
@@ -5134,11 +5128,6 @@ def tracking_bench(adapter_self, config=None, **kwargs):
51345128
bench_compile_called.append(True)
51355129
return original_bench(adapter_self, config, **kwargs)
51365130

5137-
torch._dynamo.reset()
5138-
5139-
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
5140-
return torch.relu(k_add_no_configs(x, y)) + 1.0
5141-
51425131
with patch.object(
51435132
_FusionAutotuneAdapter, "bench_compile_config", tracking_bench
51445133
):
@@ -5158,14 +5147,11 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
51585147
"bench_compile_config, not reuse the cached unfused config",
51595148
)
51605149

5150+
@requires_fusion_support
51615151
@skipIfTileIR("torch.compile missing kernel metadata on tileir")
51625152
@patch.object(k_rms_norm.settings, "torch_compile_fusion", True)
51635153
def test_inductor_output_code_has_helion_generated_triton_kernel(self):
51645154
"""Verify Helion-specific patterns appear in inductor output code."""
5165-
if not supports_torch_compile_fusion():
5166-
self.skipTest(
5167-
"torch.compile fusion requires ExternalTritonTemplateKernel support"
5168-
)
51695155

51705156
def f(x, weight, out_bias, res_bias):
51715157
x_processed = torch.relu(x) + 0.5

0 commit comments

Comments
 (0)