|
11 | 11 | from helion._testing import code_and_output |
12 | 12 | from helion._testing import onlyBackends |
13 | 13 | from helion._testing import skipUnlessPallas |
| 14 | +from helion._testing import xfailIfPallas |
14 | 15 | import helion.language as hl |
15 | 16 |
|
16 | 17 |
|
@@ -210,6 +211,23 @@ def pallas_attention( |
210 | 211 | return out.view(q_in.size()) |
211 | 212 |
|
212 | 213 |
|
@helion.kernel(backend="pallas", static_shapes=True)
def pallas_reduce_non_pow2(x: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax where the reduction dimension is not a power of 2.

    The amax/exp/sum sequence forces the backend to emit explicit index and
    mask generation, so this kernel exercises the RDIM_SIZE code path.
    """
    num_rows, _num_cols = x.size()
    result = torch.empty_like(x)
    for tile in hl.tile(num_rows):
        block = x[tile, :]
        # Subtract the row max before exponentiating for numerical stability.
        shifted = block - torch.amax(block, dim=-1, keepdim=True)
        numer = torch.exp(shifted)
        result[tile, :] = numer / torch.sum(numer, dim=-1, keepdim=True)
    return result
| 230 | + |
213 | 231 | @onlyBackends(["triton", "pallas"]) |
214 | 232 | @skipUnlessPallas("JAX/Pallas TPU not available") |
215 | 233 | class TestPallas(TestCase): |
@@ -503,5 +521,16 @@ def test_attention_emit_pipeline_non_divisible(self) -> None: |
503 | 521 | torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2) |
504 | 522 |
|
505 | 523 |
|
| 524 | + @xfailIfPallas("RDIM_SIZE rounded to next power of 2 causes shape mismatch") |
| 525 | + def test_reduce_non_pow2(self) -> None: |
| 526 | + """Reduction over non-power-of-2 dim should use exact size, not rounded.""" |
| 527 | + x = torch.randn(128, 1000, device=DEVICE, dtype=torch.float32) |
| 528 | + code, result = code_and_output( |
| 529 | + pallas_reduce_non_pow2, (x,), block_size=128 |
| 530 | + ) |
| 531 | + expected = torch.nn.functional.softmax(x, dim=-1) |
| 532 | + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) |
| 533 | + |
| 534 | + |
# Allow running this test module directly (e.g. `python test_pallas.py`).
if __name__ == "__main__":
    unittest.main()
0 commit comments