@@ -50,8 +50,6 @@ using namespace mlir::rock;
5050namespace mlir ::rock::hwtranspose {
5151namespace {
5252
53- bool archSupported (StringRef arch) { return arch.contains (" gfx950" ); }
54-
5553// Check if element type is supported for LDS transpose load
5654// - f16, bf16: ds_read_tr16_b64 (4 elements)
5755// - f8E4M3FN, f8E5M2 (OCP FP8 for gfx950): ds_read_tr8_b64 (8 elements)
@@ -78,15 +76,6 @@ static int64_t getTransposeLoadVectorLength(Type elemType) {
7876 llvm_unreachable (" Unsupported element type for LDS transpose load" );
7977}
8078
81- // Validates MFMA geometry for LDS transpose support.
82- // Supported combinations:
83- // Standard: (16,16), (16,32), (32,8), (32,16)
84- // Scaled FP8: (16,128) quad-rate, (32,64) quad-rate
85- static bool isValidMfmaGeometry (int64_t dDim, int64_t kDim ) {
86- return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 128 )) ||
87- (dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 64 ));
88- }
89-
9079// Shape of a single MFMA instruction (internal use only).
9180struct MfmaInstrShape {
9281 int64_t mnMfma;
@@ -140,8 +129,10 @@ static Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB,
140129 dec.nPerWave = nPerWave;
141130 dec.doubleBuffering = doubleBuffering;
142131
143- // Basic applicability checks
144- if (!archSupported (arch) || !DirectToLds) {
132+ // Basic applicability checks. Use the arch DB as the single source of truth
133+ // for which architectures support ds_read_tr* (kept consistent with the
134+ // verifier in RockDialect.cpp via AmdArchInfo::hasLdsTransposeLoad).
135+ if (!rock::lookupArchInfo (arch).hasLdsTransposeLoad || !DirectToLds) {
145136 return dec;
146137 }
147138
@@ -165,10 +156,28 @@ static Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB,
165156 return dec;
166157
167158 // Validate MFMA geometry
168- if (!isValidMfmaGeometry (shape.mnMfma , shape.kMfma )) {
159+ if (!isValidLdsTransposeMfmaGeometry (shape.mnMfma , shape.kMfma )) {
169160 return dec;
170161 }
171162
163+ // Reject geometry/type combinations not handled in getBasePanelOffsets:
164+ // - F16/BF16 path supports: (16,16), (16,32), (32,8), (32,16)
165+ // - FP8/BF8 path supports: (16,32), (32,16), (16,128), (32,64)
166+ // Mismatched pairs would hit llvm_unreachable in getBasePanelOffsets.
167+ // typesCompatible() above already guarantees A and B are either identical
168+ // or both FP8/BF8 variants, so checking elemTypeA is sufficient.
169+ bool isQuadRateGeometry = (shape.mnMfma == 16 && shape.kMfma == 128 ) ||
170+ (shape.mnMfma == 32 && shape.kMfma == 64 );
171+ bool isF16OnlyGeometry = (shape.mnMfma == 16 && shape.kMfma == 16 ) ||
172+ (shape.mnMfma == 32 && shape.kMfma == 8 );
173+ if (isFp8Type (elemTypeA)) {
174+ if (isF16OnlyGeometry)
175+ return dec;
176+ } else {
177+ if (isQuadRateGeometry)
178+ return dec;
179+ }
180+
172181 if (!validatePaneling (shape, operand, mPerBlock , nPerBlock, kPerBlock )) {
173182 return dec;
174183 }
@@ -325,9 +334,9 @@ LDSTransposeDecision decideLDSTransposeForOperands(
325334 }
326335 // else - neither operand usable, enableA/enableB remain false.
327336
328- // Check if numWaves is supported (1, 2, 3, 4, 8, 16)
329- // TODO: support 32 waves for WMMA
330- int64_t numWaves = (mPerBlock * nPerBlock) / ( mPerWave * nPerWave);
337+ // Check if numWaves is supported (1, 2, 4, 8, 16).
338+ // TODO: support 32 waves for WMMA.
339+ int64_t numWaves = (mPerBlock / mPerWave ) * (nPerBlock / nPerWave);
331340 if (numWaves > 16 ) {
332341 result.enableA = false ;
333342 result.enableB = false ;
@@ -346,7 +355,7 @@ LDSTransposeConfigAttr buildTransposeAttrFromParams(
346355 // INVARIANT: MFMA geometry must be valid
347356 assert (mfmaDDim > 0 && mfmaKDim > 0 &&
348357 " MFMA geometry must be set when building transpose attributes" );
349- assert (isValidMfmaGeometry (mfmaDDim, mfmaKDim) &&
358+ assert (isValidLdsTransposeMfmaGeometry (mfmaDDim, mfmaKDim) &&
350359 " Invalid MFMA geometry for LDS transpose - valid: (16,16), (16,32), "
351360 " (16,128), (32,8), (32,16), (32,64)" );
352361
@@ -582,14 +591,17 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc,
582591// ===----------------------------------------------------------------------===//
583592// writePanelVectorsToDestination - Write loaded panel vectors to destination
584593//
585- // Extracts individual f16/bf16 elements from loaded panel vectors and writes
586- // them sequentially to the destination buffer. Each panel vector contains 4
587- // elements (ds_read_tr16_b64 always returns vector<4xf16>).
594+ // Extracts individual elements from loaded panel vectors and writes them
595+ // sequentially to the destination buffer. Panel vector width depends on the
596+ // element type:
597+ // - f16/bf16: vector<4> (ds_read_tr16_b64)
598+ // - fp8/bf8: vector<8> (ds_read_tr8_b64)
588599//
589600// Parameters:
590601// panelVectors - Array of loaded panel vectors (vector<4> for f16/bf16,
591- // vector<8> for fp8/bf8) dest - Destination memref (rank-1, scalar
592- // layout) targetElems - Maximum number of elements to write
602+ // vector<8> for fp8/bf8)
603+ // dest - Destination memref (rank-1, scalar layout)
604+ // targetElems - Maximum number of elements to write
593605//
594606// Returns:
595607// success() if all target elements were written
@@ -655,13 +667,17 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc,
655667// ===----------------------------------------------------------------------===//
656668// getBasePanelOffsets - Compute per-panel LDS offsets for a given lane ID
657669//
658- // Given a wavefront lane ID and a specific MFMA layout (L16x32, L16x16, etc.),
659- // this function computes the base byte offsets into LDS memory where each
660- // lane should read its operands from.
670+ // Given a wavefront lane ID and a specific MFMA layout, this function computes
671+ // the base byte offsets into LDS memory where each lane should read its
672+ // operands from. These offsets are derived from AMD's LDS tiling and MFMA
673+ // operand layout conventions, mapping each lane's register to the correct
674+ // element position in LDS.
661675//
662- // These offsets are derived from AMD's LDS tiling and MFMA operand layout
663- // conventions (e.g., 16x16, 16x32 panels). The goal is to map each lane's
664- // register to the correct element position in LDS.
676+ // Supported (dDim, kDim) combinations per element type:
677+ // F16 / BF16: (16,16), (16,32), (32,8), (32,16) -- ds_read_tr16_b64
678+ // FP8 / BF8: (16,32), (32,16), (16,128), (32,64) -- ds_read_tr8_b64
679+ // Any other (type, geometry) combination triggers llvm_unreachable. Callers
680+ // must validate the (type, geometry) pair upstream (see makeDecision()).
665681//
666682// Note: This is an internal helper function. Use computeLDSBaseOffsets()
667683// instead for better readability.
@@ -809,7 +825,7 @@ static SmallVector<Value> getBasePanelOffsets(PatternRewriter &b, Location loc,
809825//
810826// Parameters:
811827// dDim - MFMA D dimension (M or N, 16 or 32)
812- // kDim - MFMA K dimension (8, 16, or 32 )
828+ // kDim - MFMA K dimension (8, 16, 32, 64, or 128 )
813829// lane - Thread's lane ID within the workgroup
814830// elemType - Element type (f16, bf16, fp8, or bf8) for selecting lane mapping
815831//
@@ -841,17 +857,20 @@ static std::pair<Value, Value> computeLDSBaseOffsets(PatternRewriter &b,
841857// dimensions, and decomposes the wave ID into a 2D grid position.
842858//
843859// This version uses a deterministic layout selection based solely on the number
844- // of physical waves (1, 2, 3, or 4) . The goal is to match the wave grid to the
845- // number of available wave tiles (waveTilesInM, waveTilesInN) while choosing a
846- // stable and predictable layout.
860+ // of physical waves. The goal is to match the wave grid to the number of
861+ // available wave tiles (waveTilesInM, waveTilesInN) while choosing a stable
862+ // and predictable layout.
847863//
848864// Key principles:
849- // - physicalWaves ∈ {1, 2, 3, 4} (corresponding to 64–256 threads)
865+ // - physicalWaves ∈ {1, 2, 4, 8, 16}. Tuning generates only power-of-2
866+ // wave-tile factors (see computeDPerWave's `factor *= 2` step), so
867+ // numWaves is always a product of two power-of-2 values.
850868// - Prefer balanced or natural layouts when possible:
851869// 1 wave → 1×1
852870// 2 waves → prefer 1×2
853- // 3 waves → prefer 1×3
854871// 4 waves → prefer 2×2
872+ // 8 waves → prefer 2×4
873+ // 16 waves → prefer 4×4
855874// - If a preferred layout does not fit the available tiles, fallback logic
856875// selects the best possible layout while maintaining determinism.
857876// - The result defines which spatial tile each wave is responsible for,
@@ -908,21 +927,6 @@ computeWaveGridLayout(PatternRewriter &b, Location loc, Value waveId,
908927 }
909928 break ;
910929
911- case 3 :
912- // Three waves: prefer 1×3, fallback to 3×1 or dimension-based
913- if (waveTilesInN >= 3 ) {
914- wavesInM = 1 ;
915- wavesInN = 3 ;
916- } else if (waveTilesInM >= 3 ) {
917- wavesInM = 3 ;
918- wavesInN = 1 ;
919- } else {
920- // Fallback: choose dimension with more tiles (outer loop handles rest)
921- wavesInM = (waveTilesInN >= waveTilesInM) ? 1 : 3 ;
922- wavesInN = (waveTilesInN >= waveTilesInM) ? 3 : 1 ;
923- }
924- break ;
925-
926930 case 4 :
927931 // Four waves: prefer 2×2 (balanced), then 1×4, 4×1, or fallback
928932 if (waveTilesInM >= 2 && waveTilesInN >= 2 ) {
@@ -1286,12 +1290,13 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b,
12861290 OperandKind operand =
12871291 config.getIsOperandA () ? OperandKind::A : OperandKind::B;
12881292
1289- // Compute wave grid layout and decompose wave ID into 2D position
1293+ // Compute wave grid layout and decompose wave ID into 2D position.
12901294 FailureOr<WaveGridLayout> maybeWaveGrid = computeWaveGridLayout (
12911295 b, loc, waveId, mPerWave , nPerWave, mPerBlock , nPerBlock);
1292- assert (succeeded (maybeWaveGrid) &&
1293- " If we decided to use transpose load, this must work" );
1294- WaveGridLayout waveGrid = maybeWaveGrid.value ();
1296+ if (failed (maybeWaveGrid))
1297+ return op.emitOpError (
1298+ " unsupported wave grid layout for LDS transpose load" );
1299+ WaveGridLayout waveGrid = *maybeWaveGrid;
12951300 Value waveM = waveGrid.waveM ;
12961301 Value waveN = waveGrid.waveN ;
12971302
@@ -1450,8 +1455,13 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b,
14501455 << expectedLoads << " , got " << panelVectors.size ();
14511456 }
14521457
1453- // Write loaded panel vectors to destination buffer
1454- // Destination is rank-1 with scalar sequential layout
1458+ // Write loaded panel vectors to destination buffer.
1459+ // Destination must be rank-1 with a static shape; we cannot statically
1460+ // size the writes otherwise.
1461+ if (destType.getRank () != 1 || destType.isDynamicDim (0 )) {
1462+ return op.emitOpError (
1463+ " LDS transpose load destination must be rank-1 with a static shape" );
1464+ }
14551465 int64_t destCap = destType.getShape ()[0 ];
14561466 int64_t targetElems = std::min<int64_t >(sliceElems, destCap);
14571467
0 commit comments