Skip to content

[Pallas] Add expected-failure test for non-power-of-2 RDIM size#1945

Merged
norx1991 merged 1 commit intomainfrom
yifeixu/pallas-rdim-no-pow2-test
Apr 4, 2026
Merged

[Pallas] Add expected-failure test for non-power-of-2 RDIM size#1945
norx1991 merged 1 commit intomainfrom
yifeixu/pallas-rdim-no-pow2-test

Conversation

@norx1991
Copy link
Copy Markdown
Contributor

@norx1991 norx1991 commented Apr 3, 2026

Summary

  • Adds test_reduce_non_pow2 to test/test_pallas.py to document a Pallas bug where _RDIM_SIZE is rounded to the next power of 2 (e.g., 1000→1024), but Pallas block refs use the exact dimension size, causing shape mismatches.
  • The test uses a softmax kernel (amax + exp + sum) with a non-power-of-2 reduction dim (1000) to force explicit index/mask generation, exercising the _RDIM_SIZE code path.
  • Marked with @xfailIfPallas since the fix has not landed yet.

Test plan

  • Verified test fails without fix: ValueError: Incompatible shapes for broadcasting: shapes=[(128, 1024), (128, 1000)]

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 3, 2026
@norx1991 norx1991 force-pushed the yifeixu/pallas-rdim-no-pow2-test branch 3 times, most recently from d852bde to a826a11 Compare April 3, 2026 23:00
@norx1991 norx1991 marked this pull request as ready for review April 3, 2026 23:10
RDIM_SIZE is rounded to next power of 2, but Pallas block refs use
the exact dimension size, causing shape mismatches (e.g., 1000 vs 1024).

Adds test_reduce_non_pow2 as expectedFailure to document the bug.
@norx1991 norx1991 force-pushed the yifeixu/pallas-rdim-no-pow2-test branch from a826a11 to 02c6b97 Compare April 3, 2026 23:38
@norx1991 norx1991 merged commit f28637d into main Apr 4, 2026
21 checks passed
@norx1991 norx1991 deleted the yifeixu/pallas-rdim-no-pow2-test branch April 4, 2026 05:55
norx1991 added a commit that referenced this pull request Apr 6, 2026
Pallas block refs use exact tensor dimensions, so rounding RDIM_SIZE
to the next power of 2 (e.g., 1000→1024) causes shape mismatches.
Add Backend.static_rdim_size() and override it in PallasBackend to
return the exact size. Also override next_power_of_2_host_expr to
be a no-op for Pallas.

Removes @xfailIfPallas from test_reduce_non_pow2 added in #1945.
norx1991 added a commit that referenced this pull request Apr 6, 2026
Pallas block refs use exact tensor dimensions, so rounding RDIM_SIZE
to the next power of 2 (e.g., 1000→1024) causes shape mismatches.
Add Backend.static_rdim_size() and override it in PallasBackend to
return the exact size. Also override next_power_of_2_host_expr to
be a no-op for Pallas.

Removes @xfailIfPallas from test_reduce_non_pow2 added in #1945.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants