
[Pallas] Fix fori_loop multi-dim index decomposition#1917

Draft
norx1991 wants to merge 3 commits intomainfrom
yifeixu/pallas-fori-loop-multidim

Conversation

norx1991 (Contributor) commented Apr 2, 2026

Summary

When _codegen_fori_loop handles a multi-dimensional inner loop (e.g. hl.tile([m, n])), the iteration was flattened into a single jax.lax.fori_loop with divmod to recover per-dimension indices. Previously, the raw flattened loop variable _j was used directly for all dimensions, producing incorrect indices for 2+ inner dims.

This PR fixes the indexing by replacing the flat divmod decomposition with nested fori_loop calls, one per dimension, so each dimension's loop variable is used directly. Benchmarking on TPU v7 showed the nested fori_loop form is ~5-10% faster than the divmod form.

Generated code for a 2D inner loop now looks like:

```python
def _fori_body_0(_j0, _):
    def _fori_body_1(_j1, _):
        ...  # DMA copies and compute using _j0, _j1 directly
    jax.lax.fori_loop(0, grid_n, _fori_body_1, None)
jax.lax.fori_loop(0, grid_m, _fori_body_0, None)
```

This issue was discovered while working toward enabling rms_norm_bwd on the Pallas/TPU backend, which uses a 2D inner tile.
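For intuition, the two indexing schemes can be compared in plain Python. This is a stand-in sketch, not the actual generated Pallas code; `grid_m`/`grid_n` and the function names are illustrative:

```python
# Stand-in sketch (plain Python, no JAX): the flat loop with divmod
# recovery and the nested loops visit (j0, j1) in the same row-major
# order. The original bug was using the flat variable _j itself for
# every dimension instead of the recovered (_j0, _j1).
def flat_order(grid_m, grid_n):
    pairs = []
    for _j in range(grid_m * grid_n):
        _j0, _j1 = divmod(_j, grid_n)  # row-major recovery
        pairs.append((_j0, _j1))
    return pairs

def nested_order(grid_m, grid_n):
    pairs = []
    for _j0 in range(grid_m):
        for _j1 in range(grid_n):
            pairs.append((_j0, _j1))
    return pairs
```

Both functions enumerate the same index pairs in the same order; the nested form just avoids the division and modulo on every iteration.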

When a fori_loop has multiple inner block_ids, the iteration is
flattened into a single loop. Previously the raw loop variable was
used for all dimensions' DMA slicing and offset expressions, which
is incorrect for 2+ inner dims. Decompose the flattened index into
per-dimension indices (row-major: divmod by suffix grid sizes).
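The row-major decomposition described above generalizes to any number of inner dimensions. A hedged sketch (the helper name `decompose_index` is hypothetical, not from the PR):

```python
def decompose_index(flat_idx, grid_sizes):
    """Recover per-dimension indices from a row-major flattened index
    by repeated divmod against the grid sizes, last dimension fastest."""
    idxs = []
    for size in reversed(grid_sizes):
        flat_idx, i = divmod(flat_idx, size)
        idxs.append(i)
    return tuple(reversed(idxs))
```

For a grid of sizes `[2, 3, 4]`, flat index 7 decomposes to `(0, 1, 3)` since 0*12 + 1*4 + 3 = 7.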
meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Apr 2, 2026
norx1991 marked this pull request as ready for review April 2, 2026 23:32
v0i0 (Contributor) commented Apr 2, 2026

could you check if there's a perf difference between doing this w/ div&mod vs generating (a) nested fori (b) having running counters that get reset in a standalone jax-pallas sample?
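Option (b), running counters that get reset, might look like the following plain-Python sketch (illustrative only; the variable names and structure are assumptions, not generated code):

```python
# Running counters: one flat loop, with per-dimension counters that are
# incremented and reset manually instead of recovered via divmod.
def counter_order(grid_m, grid_n):
    pairs = []
    j0, j1 = 0, 0
    for _ in range(grid_m * grid_n):
        pairs.append((j0, j1))
        j1 += 1
        if j1 == grid_n:  # reset the inner counter, carry into the outer
            j1 = 0
            j0 += 1
    return pairs
```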

norx1991 (Contributor, Author) commented Apr 3, 2026

> could you check if there's a perf difference between doing this w/ div&mod vs generating (a) nested fori (b) having running counters that get reset in a standalone jax-pallas sample?

Good catch! The result is here. Both are faster than div/mod. Which one would you suggest? Nested fori_loop seems like the more conventional path.

v0i0 (Contributor) commented Apr 3, 2026

>> could you check if there's a perf difference between doing this w/ div&mod vs generating (a) nested fori (b) having running counters that get reset in a standalone jax-pallas sample?
>
> Good catch! The result is here. Both are faster than div/mod. Which one would you suggest? Seems like nested fori_loop will be the more traditional path.

nested fori seems the most natural

…sition

Benchmarking on TPU v7 showed nested fori_loop is ~5-10% faster than
flat fori_loop with divmod index decomposition. Replace the single
flattened fori_loop with nested fori_loop calls — one per dimension —
where each dimension's loop variable is used directly instead of being
recovered via division and modulo.
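Under a minimal pure-Python stand-in for jax.lax.fori_loop semantics (the stand-in and the accumulator-based body below are illustrative, not the generated code), the nested structure behaves like this:

```python
def fori_loop(lower, upper, body, init):
    # Minimal stand-in for jax.lax.fori_loop: thread a carry value
    # through body(i, carry) for i in [lower, upper).
    carry = init
    for i in range(lower, upper):
        carry = body(i, carry)
    return carry

def nested_pairs(grid_m, grid_n):
    # One fori_loop per dimension; each dimension's loop variable
    # (_j0, _j1) is used directly, with no divmod recovery.
    def _fori_body_0(_j0, acc):
        def _fori_body_1(_j1, acc):
            return acc + [(_j0, _j1)]
        return fori_loop(0, grid_n, _fori_body_1, acc)
    return fori_loop(0, grid_m, _fori_body_0, [])
```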
norx1991 marked this pull request as draft April 3, 2026 23:13