Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ def adjust_block_size_constraints(

# Tiling size for 1D arrays. Mosaic lowering enforces that rank-1
# BlockSpec block shapes are a multiple of 128 * (32 // bitwidth).
tiling_1d = 128 * (32 // min_element_bits)
tiling_1d = 128 * 32 // min_element_bits

# Map block_id -> minimum dim_from_end across all tensors
min_dim_from_end: dict[int, int] = {}
Expand Down Expand Up @@ -1329,7 +1329,7 @@ def _compute_block_spec_info(
dim_size = tensor.shape[d]
if tensor.ndim == 1 and isinstance(dim_size, int):
bitwidth = tensor.dtype.itemsize * 8
tiling_1d = 128 * (32 // bitwidth)
tiling_1d = 128 * 32 // bitwidth
if bs != dim_size and bs % tiling_1d != 0:
return None
block_shape.append(bs)
Expand Down
19 changes: 19 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,25 @@ def test_reduce_non_pow2(self) -> None:
expected = torch.nn.functional.softmax(x, dim=-1)
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

def test_int64_1d_tensor_block_spec(self) -> None:
"""1D int64 tensor must not crash block spec computation.

Regression: tiling_1d = 128 * (32 // 64) was 0 for int64,
causing ZeroDivisionError in bs % tiling_1d. Fixed by
rewriting as 128 * 32 // bitwidth = 64.
"""
x = torch.arange(256, device=DEVICE, dtype=torch.int64)
y = torch.arange(256, device=DEVICE, dtype=torch.int64)
# Verify codegen succeeds (no ZeroDivisionError).
# We only test code generation, not execution, because JAX
# truncates int64 → int32 without JAX_ENABLE_X64.
from helion import Config

bound = add_kernel.bind((x, y))
config = Config(block_sizes=[64])
code = bound.to_triton_code(config)
self.assertIn("_BLOCK_SIZE_0", code)


if __name__ == "__main__":
unittest.main()
Loading