Skip to content

Commit aac9312

Browse files
Tighten DPP eligibility to require exact thread-count match and use arch DB for wave size
- Change the canUseDPP condition from >= to == for blockSize vs clusterSize * nonReductionDimSizeProduct, to prevent potential out-of-bounds LDS writes by extra threads when blockSize exceeds the exact thread count needed for the DPP layout.
- Replace the hard-coded chipset major-version heuristic in SubgroupReduceToDPP with rock::lookupArchInfo(chip).waveSize for a more robust subgroup-size derivation.
- Update the lowering_blockwise_broadcast_reduce test to use dimensions where blockSize == clusterSize * nrDimProd (8 == 2 * 4).
1 parent 91db7cc commit aac9312

3 files changed

Lines changed: 13 additions & 13 deletions

File tree

mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1360,15 +1360,16 @@ struct BlockwiseReduceRewritePattern
13601360
// 2. More than 1 reduction thread (at least 2 for cross-lane work)
13611361
// 3. partial_r > 2 (DPP overhead not justified for partial_r=2)
13621362
// 4. Reduction threads fit within a single wave
1363-
// 5. Block has enough threads or non-reduction dim is trivial
1363+
// 5. blockSize == clusterSize * nonReductionDimSizeProduct, or
1364+
// nonReductionDimSizeProduct == 1.
13641365
// Otherwise, fall back to LDS-based tree reduction.
13651366
int64_t maxActiveReductionThreads = threadViewShape[rTidDim];
13661367
int64_t clusterSize = llvm::PowerOf2Ceil(maxActiveReductionThreads);
13671368
int64_t partialR = partialRegTensorShape[rDim];
13681369
bool canUseDPP = llvm::isPowerOf2_64(maxActiveReductionThreads) &&
13691370
(maxActiveReductionThreads > 1) && (partialR > 2) &&
13701371
(maxActiveReductionThreads <= waveSize) &&
1371-
(blockSize >= maxActiveReductionThreads *
1372+
(blockSize == maxActiveReductionThreads *
13721373
nonReductionDimSizeProduct ||
13731374
nonReductionDimSizeProduct == 1);
13741375
// DPP path: contiguous threads reduce together (rtid = tid % cluster).

mlir/lib/Dialect/Rock/Transforms/SubgroupReduceToDPP.cpp

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1919
#include "mlir/Dialect/GPU/Transforms/Passes.h"
2020
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
21+
#include "mlir/Dialect/Rock/IR/AmdArchDb.h"
2122
#include "mlir/Pass/Pass.h"
2223
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2324

@@ -51,10 +52,7 @@ struct RockSubgroupReduceToDPPPass
5152
MLIRContext *ctx = &getContext();
5253
RewritePatternSet patterns(ctx);
5354

54-
unsigned subgroupSize = 64;
55-
if (maybeChipset->majorVersion >= 10) {
56-
subgroupSize = 32;
57-
}
55+
unsigned subgroupSize = rock::lookupArchInfo(chip).waveSize;
5856

5957
populateGpuBreakDownSubgroupReducePatterns(
6058
patterns, /*maxShuffleBitwidth=*/32, PatternBenefit(3));

mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -73,26 +73,27 @@ func.func @rock_blockwise_reducesum_nr_threads_gt_blocksize(%input_reg : memref<
7373

7474
// -----
7575

76-
#inputView = #rock.transform_map<affine_map<(d0, d1) -> (d1, d0)> by [<PassThrough ["tid"] at [0] -> ["r"] at [1]>, <PassThrough ["iter"] at [1] -> ["nr_per_bid"] at [0]>] bounds = [10, 3] -> [3, 10]>
77-
#inputView_tid = #rock.transform_map<affine_map<(d0) -> (0, d0)> by [<Merge{1, 10} ["tid"] at [0] -> ["nr_per_bid", "r"] at [0, 1]>] bounds = [10] -> [1, 10]>
78-
#inputView_iter = #rock.transform_map<affine_map<(d0) -> (d0, 0)> by [<Merge{3, 1} ["iter"] at [0] -> ["nr_per_bid", "r"] at [0, 1]>] bounds = [3] -> [3, 1]>
76+
#inputView = #rock.transform_map<affine_map<(d0, d1) -> (d1, d0)> by [<PassThrough ["tid"] at [0] -> ["r"] at [1]>, <PassThrough ["iter"] at [1] -> ["nr_per_bid"] at [0]>] bounds = [8, 4] -> [4, 8]>
77+
#inputView_tid = #rock.transform_map<affine_map<(d0) -> (0, d0)> by [<Merge{1, 8} ["tid"] at [0] -> ["nr_per_bid", "r"] at [0, 1]>] bounds = [8] -> [1, 8]>
78+
#inputView_iter = #rock.transform_map<affine_map<(d0) -> (d0, 0)> by [<Merge{4, 1} ["iter"] at [0] -> ["nr_per_bid", "r"] at [0, 1]>] bounds = [4] -> [4, 1]>
7979
// CHECK-LABEL: func @rock_blockwise_reducesum_rthreads_fix
80-
func.func @rock_blockwise_reducesum_rthreads_fix(%input_reg : memref<3xf32, #gpu.address_space<private>>, %output_reg : memref<3xf32, #gpu.address_space<private>>, %ws_lds : memref<30xf32, #gpu.address_space<workgroup>>) attributes{rock.arch = "##TOKEN_ARCH##", block_size = 10 : i32, grid_size = 2 : i32, rock.kernel} {
80+
func.func @rock_blockwise_reducesum_rthreads_fix(%input_reg : memref<4xf32, #gpu.address_space<private>>, %output_reg : memref<4xf32, #gpu.address_space<private>>, %ws_lds : memref<32xf32, #gpu.address_space<workgroup>>) attributes{rock.arch = "##TOKEN_ARCH##", block_size = 8 : i32, grid_size = 2 : i32, rock.kernel} {
8181
// Compute rthread index and nr index from tid
82+
// blockSize=8, nrDimProd=4, rTid=2, cs=2 -> cs*nrDimProd=8==blockSize
8283
// CHECK-DAG: %[[TID:.*]] = rock.workitem_id : index
8384
// CHECK: %[[RTID:.*]] = arith.andi %[[TID]], %c1
8485
// CHECK: %[[NRTID:.*]] = arith.shrui %[[TID]], %c1
8586

86-
// Threadwise partial reduction uses rDimPerRThread=5
87+
// Threadwise partial reduction uses rDimPerRThread=4
8788
// CHECK: rock.transforming_for
88-
// CHECK-SAME: bounds [1, 1, 5]
89+
// CHECK-SAME: bounds [1, 1, 4]
8990
// DPP subgroup reduce replaces tree reduction
9091
// CHECK: gpu.subgroup_reduce add {{.*}} cluster(size = 2)
9192
// CHECK: arith.cmpi eq, %[[RTID]], %c0
9293
// CHECK: scf.if
9394
// CHECK: rock.lds_barrier
9495
// CHECK: rock.threadwise_read_into
95-
rock.blockwise_broadcast_reduce sum [#inputView][#inputView_tid][#inputView_iter]%input_reg into %output_reg using %ws_lds {axis = 1 : index, blockSize = 10 : i32, nrDimPerThread = 3 : index} : memref<3xf32, #gpu.address_space<private>> using memref<30xf32, #gpu.address_space<workgroup>> into memref<3xf32, #gpu.address_space<private>>
96+
rock.blockwise_broadcast_reduce sum [#inputView][#inputView_tid][#inputView_iter]%input_reg into %output_reg using %ws_lds {axis = 1 : index, blockSize = 8 : i32, nrDimPerThread = 4 : index} : memref<4xf32, #gpu.address_space<private>> using memref<32xf32, #gpu.address_space<workgroup>> into memref<4xf32, #gpu.address_space<private>>
9697
return
9798
}
9899

0 commit comments

Comments
 (0)