[Pallas] Fix fori_loop multi-dim index decomposition #1917
Conversation
When a fori_loop has multiple inner block_ids, the iteration is flattened into a single loop. Previously, the raw loop variable was used for all dimensions' DMA slicing and offset expressions, which is incorrect for two or more inner dims. This change decomposes the flattened index into per-dimension indices (row-major: divmod by suffix grid sizes).
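The row-major decomposition described above can be sketched in plain Python. This is a standalone illustration, not the actual generated Pallas code; `decompose` is a hypothetical helper name:

```python
# Row-major decomposition of a flattened loop index into per-dimension
# indices via repeated divmod by suffix sizes. Standalone sketch;
# `decompose` is a hypothetical helper, not the actual generated code.

def decompose(flat_index, grid):
    """Split a flat row-major index over `grid` into per-dim indices."""
    idx = []
    for size in reversed(grid):
        flat_index, i = divmod(flat_index, size)
        idx.append(i)
    return tuple(reversed(idx))

# For a 2D inner grid of shape (3, 4), flat index 11 is (row 2, col 3):
assert decompose(11, (3, 4)) == (2, 3)
# The decomposition visits indices in the same order as nested loops:
assert [decompose(k, (2, 3)) for k in range(6)] == \
    [(i, j) for i in range(2) for j in range(3)]
```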
Could you check if there's a perf difference between doing this w/ div&mod vs generating (a) nested fori (b) having running counters that get reset, in a standalone jax-pallas sample?
Good catch! The result is here. Both are faster than div/mod. Which one would you suggest? Seems like nested fori_loop will be the more traditional path.
Nested fori seems the most natural.
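For reference, alternative (b) from the discussion — running counters that get reset — can be sketched in plain Python. This is a hypothetical standalone illustration, not code from the PR:

```python
# Sketch of alternative (b): one flat loop with running per-dimension
# counters that wrap on overflow, instead of a divmod at each step.
# Hypothetical standalone illustration; not code from the PR.

def flat_with_counters(m, n):
    visited = []
    i = j = 0
    for _ in range(m * n):   # single flattened loop
        visited.append((i, j))
        j += 1               # advance the innermost counter...
        if j == n:           # ...and carry into the outer dim on wrap
            j = 0
            i += 1
    return visited

# Same visit order as nested loops over (i, j):
assert flat_with_counters(2, 3) == \
    [(i, j) for i in range(2) for j in range(3)]
```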
Benchmarking on TPU v7 showed nested fori_loop is ~5-10% faster than a flat fori_loop with divmod index decomposition. Replace the single flattened fori_loop with nested fori_loop calls — one per dimension — where each dimension's loop variable is used directly instead of being recovered via division and modulo.
Summary
When `_codegen_fori_loop` handles a multi-dimensional inner loop (e.g. `hl.tile([m, n])`), the iteration was flattened into a single `jax.lax.fori_loop` with divmod to recover per-dimension indices. Previously, the raw flattened loop variable `_j` was used directly for all dimensions, producing incorrect indices for 2+ inner dims.

This PR fixes the indexing and uses nested `fori_loop` calls (one per dimension) instead of flat divmod decomposition. Benchmarking on TPU v7 showed nested `fori_loop` is ~5-10% faster than divmod.

Generated code for a 2D inner loop now looks like:
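The nested structure can be sketched in plain Python, with a stand-in that has the semantics of `jax.lax.fori_loop`. Names like `_i`, `_j`, `m`, `n` are illustrative, not the exact generated identifiers:

```python
# Plain-Python sketch of the nested-loop structure the codegen now emits
# for a 2D inner tile. `fori_loop` below is a stand-in with the carry
# semantics of jax.lax.fori_loop; identifiers are illustrative only.

def fori_loop(lower, upper, body, init):
    # jax.lax.fori_loop semantics: thread a carry value through the body.
    carry = init
    for i in range(lower, upper):
        carry = body(i, carry)
    return carry

m, n = 3, 4
visited = []

def outer_body(_i, acc):
    def inner_body(_j, acc):
        # Each dimension's loop variable is used directly -- no divmod.
        visited.append((_i, _j))
        return acc
    return fori_loop(0, n, inner_body, acc)

fori_loop(0, m, outer_body, None)

# Same iteration order as the old flat loop with divmod decomposition:
assert visited == [(i, j) for i in range(m) for j in range(n)]
```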
This issue was discovered while working toward enabling `rms_norm_bwd` on the Pallas/TPU backend, which uses a 2D inner tile.