Commit d848c74

Drop the escape in canUseDPP.
With the nonReductionDimSizeProduct == 1 escape, the DPP path could be taken when blockSize > maxActiveReductionThreads. The extra threads would then get nrtid >= 1, outside the valid [0, 1) range, and would compute out-of-bounds LDS coordinates. Tuning data across f16/f32/int8 attention configs shows nrDimProd (nonReductionDimSizeProduct) is always >= 16, so the escape was never triggered and removing it does not change behavior for any current configuration.
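
The out-of-range case described above can be sketched in a few lines. This is an illustration only, not code from the patch: the helper name countOutOfRangeThreads is hypothetical, and the mapping nrtid = tid / clusterSize is an assumption inferred from the "rtid = tid % cluster" comment in the diff.

```cpp
#include <cstdint>

// Hypothetical helper (not part of the patch). Counts the threads in a block
// whose non-reduction thread id falls outside
// [0, nonReductionDimSizeProduct).
// Assumed DPP-style mapping: rtid = tid % clusterSize,
//                            nrtid = tid / clusterSize.
int64_t countOutOfRangeThreads(int64_t blockSize, int64_t clusterSize,
                               int64_t nonReductionDimSizeProduct) {
  int64_t outOfRange = 0;
  for (int64_t tid = 0; tid < blockSize; ++tid) {
    int64_t nrtid = tid / clusterSize;
    // A thread with nrtid beyond the non-reduction extent has no valid
    // (nrtid, rtid) pair, so any LDS coordinate derived from it is invalid.
    if (nrtid >= nonReductionDimSizeProduct)
      ++outOfRange;
  }
  return outOfRange;
}
```

With illustrative numbers (not taken from the tuning data), countOutOfRangeThreads(256, 64, 1) returns 192: only the first 64 threads have nrtid == 0. With exact packing, for example countOutOfRangeThreads(256, 64, 4), the count is 0.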
1 parent aac9312 commit d848c74

1 file changed

Lines changed: 5 additions & 4 deletions

mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp

@@ -1360,8 +1360,10 @@ struct BlockwiseReduceRewritePattern
   // 2. More than 1 reduction thread (at least 2 for cross-lane work)
   // 3. partial_r > 2 (DPP overhead not justified for partial_r=2)
   // 4. Reduction threads fit within a single wave
-  // 5. blockSize == clusterSize * nonReductionDimSizeProduct, or
-  // nonReductionDimSizeProduct == 1.
+  // 5. Exact thread packing: blockSize == clusterSize *
+  // nonReductionDimSizeProduct. This guarantees every thread maps to
+  // a valid (nrtid, rtid) pair, so LDS coordinates derived from them
+  // are in-bounds.
   // Otherwise, fall back to LDS-based tree reduction.
   int64_t maxActiveReductionThreads = threadViewShape[rTidDim];
   int64_t clusterSize = llvm::PowerOf2Ceil(maxActiveReductionThreads);
@@ -1370,8 +1372,7 @@ struct BlockwiseReduceRewritePattern
   (maxActiveReductionThreads > 1) && (partialR > 2) &&
   (maxActiveReductionThreads <= waveSize) &&
   (blockSize == maxActiveReductionThreads *
-  nonReductionDimSizeProduct ||
-  nonReductionDimSizeProduct == 1);
+  nonReductionDimSizeProduct);
   // DPP path: contiguous threads reduce together (rtid = tid % cluster).
   // Tree path: scattered layout (rtid = tid /
   // nonReductionDimSizeProduct).
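
To see where the old and new predicates disagree, the clauses visible in the second hunk can be pulled into standalone functions. This is a sketch under stated assumptions: ReduceConfig, visibleCommonClauses, withEscape and exactPacking are hypothetical names, and any clause that sits above the hunk (condition 1 is not shown in the diff) is deliberately left out.

```cpp
#include <cstdint>

// Hypothetical container; field names mirror identifiers seen in the diff.
struct ReduceConfig {
  int64_t blockSize;
  int64_t maxActiveReductionThreads;
  int64_t nonReductionDimSizeProduct;
  int64_t partialR;
  int64_t waveSize;
};

// Clauses shown in the hunk that the patch leaves unchanged. Clauses located
// above the hunk are not reproduced here.
bool visibleCommonClauses(const ReduceConfig &c) {
  return (c.maxActiveReductionThreads > 1) && (c.partialR > 2) &&
         (c.maxActiveReductionThreads <= c.waveSize);
}

// Packing clause before the patch: exact packing OR the escape.
bool withEscape(const ReduceConfig &c) {
  return visibleCommonClauses(c) &&
         (c.blockSize ==
              c.maxActiveReductionThreads * c.nonReductionDimSizeProduct ||
          c.nonReductionDimSizeProduct == 1);
}

// Packing clause after the patch: exact packing only.
bool exactPacking(const ReduceConfig &c) {
  return visibleCommonClauses(c) &&
         (c.blockSize ==
          c.maxActiveReductionThreads * c.nonReductionDimSizeProduct);
}
```

The two functions differ only when nonReductionDimSizeProduct == 1 and blockSize != maxActiveReductionThreads. For a made-up config such as ReduceConfig{256, 64, 1, 4, 64}, withEscape returns true and exactPacking returns false. For ReduceConfig{256, 16, 16, 4, 64} both return true.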
