[Pallas] Use exact RDIM size instead of next-power-of-2 #1954

Draft
norx1991 wants to merge 1 commit into main from yifeixu/pallas-rdim-no-pow2-fix

norx1991 (Contributor) commented Apr 6, 2026

Summary

  • Pallas block refs use exact tensor dimensions, but _RDIM_SIZE was rounded to the next power of 2 (e.g., 1000→1024), causing ValueError: Incompatible shapes for broadcasting.
  • Add Backend.static_rdim_size() and Backend.dynamic_rdim_size_expr() which default to next-power-of-2 behavior, and override both in PallasBackend to return exact sizes.
  • Remove @xfailIfPallas from test_reduce_non_pow2, which was added in "[Pallas] Add expected-failure test for non-power-of-2 RDIM size" (#1945).
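
The override pattern described in the summary can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: the class bodies, the argument names (`numel`, `numel_expr`), the string form of the dynamic expression, and the `next_power_of_2` helper are all assumptions made for the example. Only the names `Backend`, `PallasBackend`, `static_rdim_size`, and `dynamic_rdim_size_expr` come from the description above.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of 2 that is >= n (hypothetical helper for this sketch)."""
    if n < 1:
        raise ValueError("n must be positive")
    return 1 << (n - 1).bit_length()


class Backend:
    """Stand-in for the base backend; signatures are assumed, not taken from the PR."""

    def static_rdim_size(self, numel: int) -> int:
        # Default behavior: round the reduction dimension up to a power of 2.
        return next_power_of_2(numel)

    def dynamic_rdim_size_expr(self, numel_expr: str) -> str:
        # Default behavior: emit host code that rounds up at run time.
        return f"next_power_of_2({numel_expr})"


class PallasBackend(Backend):
    """Stand-in for the Pallas backend, which needs exact sizes."""

    def static_rdim_size(self, numel: int) -> int:
        # Pallas block refs use exact tensor dimensions, so do not round.
        return numel

    def dynamic_rdim_size_expr(self, numel_expr: str) -> str:
        # Pass the size expression through unchanged.
        return numel_expr
```

Under these assumptions, `Backend().static_rdim_size(1000)` gives 1024, while `PallasBackend().static_rdim_size(1000)` gives 1000, so the reduction dimension matches the block ref shape.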

Test plan

  • test_reduce_non_pow2 fails without fix: ValueError: Incompatible shapes for broadcasting: shapes=[(128, 1024), (128, 1000)]
  • test_reduce_non_pow2 passes with fix (_RDIM_SIZE_1 = 1000)
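
To see why the two shapes in the failure message cannot be combined, here is a small pure-Python sketch of NumPy-style broadcasting rules. The `broadcast_shapes` function is a hypothetical helper written for this illustration; it is not part of the PR or of Pallas, and its error text only imitates the message quoted above.

```python
def broadcast_shapes(a: tuple, b: tuple) -> tuple:
    """Return the broadcast result shape, or raise ValueError on a mismatch."""
    result = []
    # Compare dimensions from the trailing end; missing leading dims count as 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        # Two dimensions are compatible only if they are equal or one of them is 1.
        if da != db and da != 1 and db != 1:
            raise ValueError(
                f"Incompatible shapes for broadcasting: shapes=[{a}, {b}]"
            )
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


# With the rounded size, the trailing dims 1024 and 1000 differ and neither is 1.
try:
    broadcast_shapes((128, 1024), (128, 1000))
except ValueError as err:
    print(err)

# With the exact size, both operands agree.
print(broadcast_shapes((128, 1000), (128, 1000)))
```

The first call raises because 1024 and 1000 are unequal and neither is 1; once the reduction dimension is kept at 1000, the shapes are identical and the check passes.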

meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Apr 6, 2026

Pallas block refs use exact tensor dimensions, so rounding RDIM_SIZE
to the next power of 2 (e.g., 1000→1024) causes shape mismatches.
Add Backend.static_rdim_size() and override it in PallasBackend to
return the exact size. Also override next_power_of_2_host_expr to
be a no-op for Pallas.

Removes @xfailIfPallas from test_reduce_non_pow2 added in #1945.

norx1991 force-pushed the yifeixu/pallas-rdim-no-pow2-fix branch from 3b39bfb to 28c1f75 on April 6, 2026 05:04