Skip to content

Commit 3b39bfb

Browse files
committed
[Pallas] Use exact RDIM size instead of next-power-of-2
Pallas block refs use exact tensor dimensions, so rounding RDIM_SIZE to the next power of 2 (e.g., 1000→1024) causes shape mismatches. Add Backend.static_rdim_size() and override it in PallasBackend to return the exact size. Also override next_power_of_2_host_expr to be a no-op for Pallas. Removes @xfailIfPallas from test_reduce_non_pow2 added in #1945.
1 parent b7b43b5 commit 3b39bfb

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

helion/_compiler/backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,12 @@ def next_power_of_2_host_expr(self, expr: str) -> str:
284284
"""Generate a host-side next-power-of-2 expression."""
285285
raise exc.BackendUnsupported(self.name, "next_power_of_2")
286286

287+
def static_rdim_size(self, numel: int) -> int:
288+
"""Return the RDIM block size for a statically known reduction dimension."""
289+
from torch._inductor.runtime.runtime_utils import next_power_of_2
290+
291+
return next_power_of_2(numel)
292+
287293
def reduction_combine_expr(
288294
self,
289295
reduction_type: str,
@@ -1121,6 +1127,15 @@ def reduction_index_expr(
11211127
def reduction_index_zero_expr(self, dtype: str) -> str:
11221128
return f"jnp.zeros([0], dtype={dtype})"
11231129

1130+
def next_power_of_2_host_expr(self, expr: str) -> str:
1131+
# Pallas block refs already have the exact tensor dimension size,
1132+
# so RDIM_SIZE must match the actual dimension (no power-of-2 rounding).
1133+
# Rounding up would create index arrays larger than the block ref.
1134+
return expr
1135+
1136+
def static_rdim_size(self, numel: int) -> int:
1137+
return numel
1138+
11241139
def adjust_block_size_constraints(
11251140
self,
11261141
block_specs: list[object],

helion/_compiler/reduction_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def codegen_preamble(self, state: CodegenState) -> None:
407407
if isinstance(numel, sympy.Integer):
408408
# Static size - issue statement immediately
409409
stmt = statement_from_string(
410-
f"{block_size_var} = {next_power_of_2(int(numel))}"
410+
f"{block_size_var} = {backend.static_rdim_size(int(numel))}"
411411
)
412412
state.codegen.host_statements.append(stmt)
413413
else:

test/test_pallas.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from helion._testing import code_and_output
1212
from helion._testing import onlyBackends
1313
from helion._testing import skipUnlessPallas
14-
from helion._testing import xfailIfPallas
1514
import helion.language as hl
1615

1716

@@ -540,7 +539,6 @@ def test_emit_pipeline_loop_order(self) -> None:
540539
expected = (x.float() @ y.float() + bias.float()).to(torch.bfloat16)
541540
torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
542541

543-
@xfailIfPallas("RDIM_SIZE rounded to next power of 2 causes shape mismatch")
544542
def test_reduce_non_pow2(self) -> None:
545543
"""Reduction over non-power-of-2 dim should use exact size, not rounded."""
546544
x = torch.randn(128, 1000, device=DEVICE, dtype=torch.float32)

0 commit comments

Comments (0)