Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/matrix.json
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@
"pytorch-version": "pytorch-nightly",
"alias": "b200",
"backend": "cute"
},
{
"runner": "linux.2xlarge",
"python-version": "3.12",
"ref-eager": false,
"image": "",
"runtime-version": "cpu",
"container-options": "",
"pytorch-version": "pytorch-nightly",
"alias": "pallas-interpret",
"backend": "pallas"
}
]
}
9 changes: 8 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ jobs:
# Verify
python -c "from torch_tpu import api; print(f'TPU device: {api.tpu_device()}')"

- name: Install Pallas interpret dependencies
if: matrix.alias == 'pallas-interpret'
run: |
source .venv/bin/activate
uv pip install 'jax==0.9.2' 'jaxlib==0.9.2' 'absl-py'

- name: Install CUTLASS CuTe DSL
if: matrix.backend == 'cute'
run: |
Expand Down Expand Up @@ -249,14 +255,15 @@ jobs:
if [[ "${{ matrix.expecttest-accept }}" == "true" ]]; then export EXPECTTEST_ACCEPT=1; fi
if [[ "${{ matrix.ref-eager }}" == "true" ]]; then export HELION_INTERPRET=1; fi
if [[ "${{ matrix.backend }}" == "tileir" ]]; then export ENABLE_TILE=1; fi
if [[ "${{ matrix.alias }}" == "pallas-interpret" ]]; then export HELION_PALLAS_INTERPRET=1; fi
export HELION_BACKEND=${{ matrix.backend }}
# -rf: print failed tests
# --timeout: max allowed time for each test
PARALLEL="-n4"
if [[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]]; then
TEST_PATH="test/test_examples_dist.py"
EXTRA_FLAGS="-rs"
elif [[ "${{ matrix.alias }}" == "tpu" ]]; then
elif [[ "${{ matrix.alias }}" == "tpu" || "${{ matrix.alias }}" == "pallas-interpret" ]]; then
TEST_PATH="."
EXTRA_FLAGS="--ignore=test/test_examples_dist.py"
PARALLEL=""
Expand Down
18 changes: 17 additions & 1 deletion helion/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def readouterr(self) -> tuple[str, str]:

def is_cuda() -> bool:
    """Return True when the active backend targets CUDA (NVIDIA GPU).

    The pallas backend never counts as CUDA, regardless of whether a
    CUDA device happens to be available on the host.
    """
    backend_is_pallas = _get_backend() == "pallas"
    return (
        not backend_is_pallas
        and _get_triton_backend() == "cuda"
        and torch.cuda.is_available()
    )


Expand Down Expand Up @@ -292,6 +294,20 @@ def xfailIfPallas(reason: str) -> Callable[[Callable], Callable]:
return xfailIfFn(lambda: _get_backend() == "pallas", reason)


def xfailIfPallasTpu(reason: str) -> Callable[[Callable], Callable]:
    """Mark test as expected failure if running with pallas on real TPU (not interpret mode)"""

    def _pallas_on_real_tpu() -> bool:
        # Interpret mode runs on CPU, so only a non-interpret pallas run counts.
        return _get_backend() == "pallas" and not is_pallas_interpret()

    return xfailIfFn(_pallas_on_real_tpu, reason)


def xfailIfPallasInterpret(reason: str) -> Callable[[Callable], Callable]:
    """Mark test as expected failure if running with pallas in interpret mode"""

    def _pallas_interpreted() -> bool:
        # Only the CPU-based interpret mode of the pallas backend qualifies.
        return _get_backend() == "pallas" and is_pallas_interpret()

    return xfailIfFn(_pallas_interpreted, reason)


def skipUnlessAMDCDNA(reason: str) -> Callable[[Callable], Callable]:
"""Skip test unless running on AMD CDNA architecture."""
from helion._compat import supports_amd_cdna_tunables
Expand Down Expand Up @@ -355,7 +371,7 @@ def wrapper(cls: type[unittest.TestCase]) -> type[unittest.TestCase]:
def skipUnlessTensorDescriptor(reason: str) -> Callable[[Callable], Callable]:
    """Skip a test unless tensor descriptors are supported.

    The predicate is evaluated lazily at test execution time to avoid
    triggering CUDA initialization during pytest-xdist collection.
    """

    def _descriptors_unavailable() -> bool:
        # De Morgan of `not is_cuda() or not supports_tensor_descriptor()`;
        # is_cuda() is checked first so the descriptor probe (which may touch
        # the driver) only runs on CUDA.
        return not (is_cuda() and supports_tensor_descriptor())

    return skipIfFn(_descriptors_unavailable, reason)


def skipUnlessTf32Supported(
Expand Down
11 changes: 4 additions & 7 deletions helion/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,20 +897,17 @@ def default_pallas_fori_launcher(


def _torch_to_jax(t: torch.Tensor) -> object:
"""Convert a torch.Tensor to a JAX array via numpy (for interpret mode on CPU)."""
"""Convert a torch.Tensor to a JAX array via DLPack (for interpret mode on CPU)."""
import jax.numpy as jnp
import numpy as np

return jnp.array(np.asarray(t.detach().cpu()))
return jnp.from_dlpack(t.detach().cpu())


def _jax_to_torch(
arr: object, *, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
"""Convert a JAX array back to a torch.Tensor via numpy (for interpret mode on CPU)."""
import numpy as np

return torch.from_numpy(np.asarray(arr)).to(dtype=dtype, device=device)
"""Convert a JAX array back to a torch.Tensor via DLPack (for interpret mode on CPU)."""
return torch.from_dlpack(arr).to(dtype=dtype, device=device)


def _torch_dtype_to_cutlass(dtype: torch.dtype) -> object:
Expand Down
3 changes: 2 additions & 1 deletion test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from helion._testing import skipIfXPU
from helion._testing import xfailIfCute
from helion._testing import xfailIfPallas
from helion._testing import xfailIfPallasTpu

_orig_matmul_fp32_precision: str = "none"
_orig_cudnn_fp32_precision: str = "none"
Expand Down Expand Up @@ -1372,7 +1373,7 @@ def test_jsd(self):
num_stages=3,
)

@xfailIfPallas("operation not supported on TPU")
@xfailIfPallasTpu("operation not supported on TPU")
def test_kl_div(self):
if _get_backend() == "cute" and "B200" in get_nvidia_gpu_model():
pytest.xfail("CuTe KL-div example still launches out of resources on B200")
Expand Down
9 changes: 5 additions & 4 deletions test/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from helion._testing import skipIfRefEager
from helion._testing import skipIfTileIR
from helion._testing import xfailIfPallas
from helion._testing import xfailIfPallasTpu
import helion.language as hl

datadir = Path(__file__).parent / "data"
Expand Down Expand Up @@ -163,7 +164,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
with self.assertRaises(helion.exc.StatementNotSupported):
kernel.bind((x,))

@xfailIfPallas("large 4D tensors may exceed TPU VMEM")
@xfailIfPallasTpu("large 4D tensors may exceed TPU VMEM")
@skipIfLowVRAM("Test requires high VRAM for [128, 128, 128, 128] tensors")
def test_3d_device_loop0(self):
args = (torch.randn([128, 128, 128, 128], device=DEVICE),)
Expand All @@ -174,7 +175,7 @@ def test_3d_device_loop0(self):
)
torch.testing.assert_close(result, torch.sin(args[0]))

@xfailIfPallas("large 4D tensors may exceed TPU VMEM")
@xfailIfPallasTpu("large 4D tensors may exceed TPU VMEM")
@skipIfLowVRAM("Test requires high VRAM for [128, 128, 128, 128] tensors")
def test_3d_device_loop1(self):
args = (torch.randn([128, 128, 128, 128], device=DEVICE),)
Expand All @@ -199,7 +200,7 @@ def test_3d_device_loop2(self):
)
torch.testing.assert_close(result, torch.sin(args[0]))

@xfailIfPallas("large 4D tensors may exceed TPU VMEM")
@xfailIfPallasTpu("large 4D tensors may exceed TPU VMEM")
@patch.object(_compat, "_supports_tensor_descriptor", lambda: False)
@skipIfLowVRAM("Test requires high VRAM for [128, 128, 128, 128] tensors")
@skipIfTileIR("TileIR does not support block_ptr indexing")
Expand Down Expand Up @@ -923,7 +924,7 @@ def test_range_num_stages(self):
code3,
)

@xfailIfPallas("range_num_stages is Triton-specific")
@xfailIfPallasTpu("range_num_stages is Triton-specific")
@skipIfTileIR("tileir backend will ignore `range_num_stages` hint")
@skipIfRefEager("not supported in ref eager mode")
def test_range_num_stages_preserved_without_aliasing(self):
Expand Down
4 changes: 4 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from helion._testing import code_and_output
from helion._testing import onlyBackends
from helion._testing import skipUnlessPallas
from helion._testing import xfailIfPallasInterpret
import helion.language as hl


Expand Down Expand Up @@ -480,6 +481,9 @@ def test_attention_fori_loop_correctness(self) -> None:
).to(device=DEVICE)
torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)

@xfailIfPallasInterpret(
"JAX interpret mode cannot discharge dynamic_slice with traced sizes",
)
def test_attention_emit_pipeline_non_divisible(self) -> None:
"""Test emit_pipeline with seq_kv not divisible by block_k.

Expand Down
4 changes: 2 additions & 2 deletions test/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from helion._testing import skipIfTileIR
from helion._testing import skipUnlessTensorDescriptor
from helion._testing import xfailIfCute
from helion._testing import xfailIfPallas
from helion._testing import xfailIfPallasTpu
import helion.language as hl

if TYPE_CHECKING:
Expand Down Expand Up @@ -434,7 +434,7 @@ def layer_norm_fwd_repro(
)
torch.testing.assert_close(result1, result2, rtol=1e-3, atol=1e-3)

@xfailIfPallas("fp16/bf16 1D tensors hit TPU Mosaic sublane alignment error")
@xfailIfPallasTpu("fp16/bf16 1D tensors hit TPU Mosaic sublane alignment error")
@skipIfTileIR("TileIR does not support log1p")
def test_fp16_math_ops_fp32_fallback(self):
"""Test that mathematical ops with fp16/bfloat16 inputs now work via fp32 fallback."""
Expand Down
Loading