Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,20 @@ def next_power_of_2_host_expr(self, expr: str) -> str:
"""Generate a host-side next-power-of-2 expression."""
raise exc.BackendUnsupported(self.name, "next_power_of_2")

def static_rdim_size(self, numel: int) -> int:
"""Return the RDIM block size for a statically known reduction dimension."""
from torch._inductor.runtime.runtime_utils import next_power_of_2

return next_power_of_2(numel)

def dynamic_rdim_size_expr(self, expr: str) -> str:
    """Generate a host-side expression for RDIM size from a dynamic dimension.

    By default this delegates to ``next_power_of_2_host_expr``. Backends such
    as Pallas that need exact sizes can override it to return ``expr``
    unchanged.
    """
    rounded_expr = self.next_power_of_2_host_expr(expr)
    return rounded_expr

def reduction_combine_expr(
self,
reduction_type: str,
Expand Down Expand Up @@ -1121,6 +1135,14 @@ def reduction_index_expr(
def reduction_index_zero_expr(self, dtype: str) -> str:
    """Return the string ``jnp.zeros([0], dtype=<dtype>)`` — a zero-length
    index tensor expression for the given dtype string."""
    return "jnp.zeros([0], dtype={})".format(dtype)

def static_rdim_size(self, numel: int) -> int:
    """Return ``numel`` unchanged.

    Pallas block refs use exact tensor dimensions, so RDIM_SIZE must match
    exactly — power-of-2 rounding would exceed the block ref.
    """
    return numel

def dynamic_rdim_size_expr(self, expr: str) -> str:
    """Return the dimension expression as-is (Pallas requires exact sizes,
    so no power-of-2 rounding is applied)."""
    return expr

def adjust_block_size_constraints(
self,
block_specs: list[object],
Expand Down
2 changes: 1 addition & 1 deletion helion/_compiler/device_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ def flush_deferred_rdim_defs(self, codegen: GenerateAST) -> None:
for var_name, expr in self.deferred_rdim_defs:
expr_str = HostFunction.current().sympy_expr(expr)
stmt = statement_from_string(
f"{var_name} = {backend.next_power_of_2_host_expr(expr_str)}"
f"{var_name} = {backend.dynamic_rdim_size_expr(expr_str)}"
)
codegen.host_statements.append(stmt)
self.deferred_rdim_defs.clear()
Expand Down
4 changes: 2 additions & 2 deletions helion/_compiler/reduction_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def codegen_preamble(self, state: CodegenState) -> None:
if isinstance(numel, sympy.Integer):
# Static size - issue statement immediately
stmt = statement_from_string(
f"{block_size_var} = {next_power_of_2(int(numel))}"
f"{block_size_var} = {backend.static_rdim_size(int(numel))}"
)
state.codegen.host_statements.append(stmt)
else:
Expand All @@ -422,7 +422,7 @@ def codegen_preamble(self, state: CodegenState) -> None:
# No dependencies - issue statement immediately
expr_str = HostFunction.current().sympy_expr(numel)
stmt = statement_from_string(
f"{block_size_var} = {backend.next_power_of_2_host_expr(expr_str)}"
f"{block_size_var} = {backend.dynamic_rdim_size_expr(expr_str)}"
)
state.codegen.host_statements.append(stmt)
current_grid = state.codegen.current_grid_state
Expand Down
2 changes: 0 additions & 2 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from helion._testing import code_and_output
from helion._testing import onlyBackends
from helion._testing import skipUnlessPallas
from helion._testing import xfailIfPallas
import helion.language as hl


Expand Down Expand Up @@ -540,7 +539,6 @@ def test_emit_pipeline_loop_order(self) -> None:
expected = (x.float() @ y.float() + bias.float()).to(torch.bfloat16)
torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)

@xfailIfPallas("RDIM_SIZE rounded to next power of 2 causes shape mismatch")
def test_reduce_non_pow2(self) -> None:
"""Reduction over non-power-of-2 dim should use exact size, not rounded."""
x = torch.randn(128, 1000, device=DEVICE, dtype=torch.float32)
Expand Down
Loading