Skip to content

Commit d852bde

Browse files
committed
[Pallas] Add expected-failure test for non-power-of-2 RDIM size
RDIM_SIZE is rounded to next power of 2, but Pallas block refs use the exact dimension size, causing shape mismatches (e.g., 1000 vs 1024). Adds test_reduce_non_pow2 as expectedFailure to document the bug.
1 parent f2b111e commit d852bde

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

test/test_pallas.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from helion._testing import code_and_output
1212
from helion._testing import onlyBackends
1313
from helion._testing import skipUnlessPallas
14+
from helion._testing import xfailIfPallas
1415
import helion.language as hl
1516

1617

@@ -210,6 +211,23 @@ def pallas_attention(
210211
return out.view(q_in.size())
211212

212213

214+
@helion.kernel(backend="pallas", static_shapes=True)
def pallas_reduce_non_pow2(x: torch.Tensor) -> torch.Tensor:
    """Softmax over a non-power-of-2 reduction dim.

    Uses amax + exp + sum which forces explicit index/mask generation,
    exercising the RDIM_SIZE code path.

    Args:
        x: 2-D input tensor; the last dimension is the reduction dim
           (intended to be non-power-of-2, e.g. 1000, to trigger the bug).

    Returns:
        Tensor of the same shape as ``x`` with softmax applied along dim -1.
    """
    # Only the row count is tiled; the reduction dim is consumed whole per row.
    n, _m = x.size()
    out = torch.empty_like(x)
    for tile_n in hl.tile(n):
        row = x[tile_n, :]
        # Numerically stable softmax: subtract the per-row max before exp.
        max_val = torch.amax(row, dim=-1, keepdim=True)
        exp_val = torch.exp(row - max_val)
        out[tile_n, :] = exp_val / torch.sum(exp_val, dim=-1, keepdim=True)
    return out
229+
230+
213231
@onlyBackends(["triton", "pallas"])
214232
@skipUnlessPallas("JAX/Pallas TPU not available")
215233
class TestPallas(TestCase):
@@ -503,5 +521,16 @@ def test_attention_emit_pipeline_non_divisible(self) -> None:
503521
torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)
504522

505523

524+
@xfailIfPallas("RDIM_SIZE rounded to next power of 2 causes shape mismatch")
def test_reduce_non_pow2(self) -> None:
    """Reduction over non-power-of-2 dim should use exact size, not rounded."""
    # 1000 is deliberately not a power of 2 (would round up to 1024).
    inp = torch.randn(128, 1000, device=DEVICE, dtype=torch.float32)
    expected = torch.nn.functional.softmax(inp, dim=-1)
    _code, result = code_and_output(
        pallas_reduce_non_pow2, (inp,), block_size=128
    )
    torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)
533+
534+
506535
if __name__ == "__main__":
507536
unittest.main()

0 commit comments

Comments
 (0)