Skip to content

Commit 28c1f75

Browse files
committed
[Pallas] Use exact RDIM size instead of next-power-of-2
Pallas block refs use exact tensor dimensions, so rounding RDIM_SIZE to the next power of 2 (e.g., 1000→1024) causes shape mismatches. Add Backend.static_rdim_size() and override it in PallasBackend to return the exact size. Also override next_power_of_2_host_expr to be a no-op for Pallas. Removes @xfailIfPallas from test_reduce_non_pow2 added in #1945.
1 parent b7b43b5 commit 28c1f75

File tree

4 files changed

+25
-5
lines changed

4 files changed

+25
-5
lines changed

helion/_compiler/backend.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -284,6 +284,20 @@ def next_power_of_2_host_expr(self, expr: str) -> str:
284284
"""Generate a host-side next-power-of-2 expression."""
285285
raise exc.BackendUnsupported(self.name, "next_power_of_2")
286286

287+
def static_rdim_size(self, numel: int) -> int:
288+
"""Return the RDIM block size for a statically known reduction dimension."""
289+
from torch._inductor.runtime.runtime_utils import next_power_of_2
290+
291+
return next_power_of_2(numel)
292+
293+
def dynamic_rdim_size_expr(self, expr: str) -> str:
294+
"""Generate a host-side expression for RDIM size from a dynamic dimension.
295+
296+
By default delegates to next_power_of_2_host_expr. Backends like Pallas
297+
that need exact sizes can override to return the expression unchanged.
298+
"""
299+
return self.next_power_of_2_host_expr(expr)
300+
287301
def reduction_combine_expr(
288302
self,
289303
reduction_type: str,
@@ -1121,6 +1135,14 @@ def reduction_index_expr(
11211135
def reduction_index_zero_expr(self, dtype: str) -> str:
11221136
return f"jnp.zeros([0], dtype={dtype})"
11231137

1138+
def static_rdim_size(self, numel: int) -> int:
1139+
# Pallas block refs use exact tensor dimensions, so RDIM_SIZE must
1140+
# match (no power-of-2 rounding that would exceed the block ref).
1141+
return numel
1142+
1143+
def dynamic_rdim_size_expr(self, expr: str) -> str:
1144+
return expr
1145+
11241146
def adjust_block_size_constraints(
11251147
self,
11261148
block_specs: list[object],

helion/_compiler/device_function.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -852,7 +852,7 @@ def flush_deferred_rdim_defs(self, codegen: GenerateAST) -> None:
852852
for var_name, expr in self.deferred_rdim_defs:
853853
expr_str = HostFunction.current().sympy_expr(expr)
854854
stmt = statement_from_string(
855-
f"{var_name} = {backend.next_power_of_2_host_expr(expr_str)}"
855+
f"{var_name} = {backend.dynamic_rdim_size_expr(expr_str)}"
856856
)
857857
codegen.host_statements.append(stmt)
858858
self.deferred_rdim_defs.clear()

helion/_compiler/reduction_strategy.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -407,7 +407,7 @@ def codegen_preamble(self, state: CodegenState) -> None:
407407
if isinstance(numel, sympy.Integer):
408408
# Static size - issue statement immediately
409409
stmt = statement_from_string(
410-
f"{block_size_var} = {next_power_of_2(int(numel))}"
410+
f"{block_size_var} = {backend.static_rdim_size(int(numel))}"
411411
)
412412
state.codegen.host_statements.append(stmt)
413413
else:
@@ -422,7 +422,7 @@ def codegen_preamble(self, state: CodegenState) -> None:
422422
# No dependencies - issue statement immediately
423423
expr_str = HostFunction.current().sympy_expr(numel)
424424
stmt = statement_from_string(
425-
f"{block_size_var} = {backend.next_power_of_2_host_expr(expr_str)}"
425+
f"{block_size_var} = {backend.dynamic_rdim_size_expr(expr_str)}"
426426
)
427427
state.codegen.host_statements.append(stmt)
428428
current_grid = state.codegen.current_grid_state

test/test_pallas.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
from helion._testing import code_and_output
1212
from helion._testing import onlyBackends
1313
from helion._testing import skipUnlessPallas
14-
from helion._testing import xfailIfPallas
1514
import helion.language as hl
1615

1716

@@ -540,7 +539,6 @@ def test_emit_pipeline_loop_order(self) -> None:
540539
expected = (x.float() @ y.float() + bias.float()).to(torch.bfloat16)
541540
torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
542541

543-
@xfailIfPallas("RDIM_SIZE rounded to next power of 2 causes shape mismatch")
544542
def test_reduce_non_pow2(self) -> None:
545543
"""Reduction over non-power-of-2 dim should use exact size, not rounded."""
546544
x = torch.randn(128, 1000, device=DEVICE, dtype=torch.float32)

0 commit comments

Comments (0)