Skip to content

Commit 519a211

Browse files
Tighten the rock.lds_transpose_load ODS result-type constraint to the
exact set of valid vector types and drop the now-redundant verifier checks. Add ldsTransposeConfig structural checks to ThreadwiseReadIntoOp::verify (rank-1 + static dest, supported element type, (geometry, type) consistency). Replace the assert + .value() in emitThreadwiseHWTranspose with emitOpError to avoid UB in release builds and reject non-rank-1 / dynamic destinations up-front. Use AmdArchInfo::hasLdsTransposeLoad for arch gating, share a single isValidLdsTransposeMfmaGeometry helper, align the numWaves formula with computeWaveGridLayout, drop the dead (tuning only emits power-of-2 wave-tile factors), and refresh doc comments. Add four negative ODS-coverage tests for the result-type constraint.
1 parent 1a3bbab commit 519a211

5 files changed

Lines changed: 162 additions & 97 deletions

File tree

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1232,7 +1232,9 @@ def Rock_LDSTransposeLoadOp
12321232
Arguments<(ins Arg<MemRefOf<[F16, BF16, F8E4M3FN, F8E5M2]>,
12331233
"LDS source buffer">:$source,
12341234
Variadic<Index>:$indices)>,
1235-
Results<(outs AnyVectorOfNonZeroRank:$result)> {
1235+
Results<(outs AnyTypeOf<
1236+
[VectorOfLengthAndType<[4], [F16, BF16]>,
1237+
VectorOfLengthAndType<[8], [F8E4M3FN, F8E5M2]>]>:$result)> {
12361238
let summary =
12371239
"Hardware-assisted LDS transpose load for matrix accelerator tile";
12381240
let description = [{

mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,18 @@ namespace mlir::rock::hwtranspose {
4949
// Operand selector (A or B matrix)
5050
enum class OperandKind { A, B };
5151

52+
// Returns true if the given (D, K) MFMA geometry is one of the geometries
53+
// recognized by the LDS transpose load lowering.
54+
// Recognized combinations:
55+
// Standard: (16,16), (16,32), (32,8), (32,16)
56+
// Scaled FP8: (16,128) quad-rate, (32,64) quad-rate
57+
// Note: this is geometry-only recognition. Element-type compatibility
58+
// (e.g., FP8-only quad-rate) is enforced separately by the caller.
59+
inline bool isValidLdsTransposeMfmaGeometry(int64_t dDim, int64_t kDim) {
60+
return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 128)) ||
61+
(dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 64));
62+
}
63+
5264
// Build LDS transpose config attribute from already-computed MFMA params.
5365
// Used in BlockwiseLoadTileToThreadwise when decision was made upstream.
5466
// Requires mfmaDDim > 0 and mfmaKDim > 0 (asserted).

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/Rock/IR/AccelEmitter.h"
2121
#include "mlir/Dialect/Rock/IR/AmdArchDb.h"
2222
#include "mlir/Dialect/Rock/IR/GetRockInfo.h"
23+
#include "mlir/Dialect/Rock/utility/LdsTransposeLoad.h"
2324
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
2425
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
2526
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -558,22 +559,14 @@ LogicalResult TransformMapAttr::verify(
558559
return success();
559560
}
560561

561-
// Helper function to check valid MFMA geometry for LDS transpose
562-
static bool isValidLdsTransposeMfmaGeometry(int64_t dDim, int64_t kDim) {
563-
// Supported geometries:
564-
// Standard: (16,16), (16,32), (32,8), (32,16)
565-
// Scaled FP8: (16,128), (32,64)
566-
return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 128)) ||
567-
(dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 64));
568-
}
569-
570562
LogicalResult LDSTransposeConfigAttr::verify(
571563
function_ref<InFlightDiagnostic()> emitError, int64_t dDim, int64_t kDim,
572564
int64_t mPerBlock, int64_t nPerBlock, int64_t kPerBlock, int64_t mPerWave,
573565
int64_t nPerWave, bool doubleBuffering, bool isOperandA) {
574566

575-
// Validate MFMA geometry
576-
if (!isValidLdsTransposeMfmaGeometry(dDim, kDim)) {
567+
// Validate MFMA geometry (geometry-only recognition; type-aware checks
568+
// are enforced by the lowering decision in LdsTransposeLoad.cpp).
569+
if (!hwtranspose::isValidLdsTransposeMfmaGeometry(dDim, kDim)) {
577570
return emitError()
578571
<< "invalid MFMA geometry (" << dDim << "x" << kDim
579572
<< ") for LDS transpose - valid combinations: "
@@ -2154,42 +2147,17 @@ LogicalResult LDSTransposeLoadOp::verify() {
21542147
if (!memSpaceCheck.value())
21552148
return emitOpError("source memory address space must be workgroup (LDS)");
21562149

2157-
// Result element type must match source element type
2150+
// Result element type must match source element type. ODS guarantees the
2151+
// result is one of the allowed vector types, so cast<VectorType> is safe.
21582152
Type srcElemType = srcType.getElementType();
2159-
VectorType resultType = getResult().getType();
2160-
Type resultElemType = resultType.getElementType();
2161-
2153+
Type resultElemType =
2154+
cast<VectorType>(getResult().getType()).getElementType();
21622155
if (resultElemType != srcElemType) {
21632156
return emitOpError("result element type (")
21642157
<< resultElemType << ") must match source element type ("
21652158
<< srcElemType << ")";
21662159
}
21672160

2168-
if (resultType.getRank() != 1)
2169-
return emitOpError("expected 1-D result vector, but got rank ")
2170-
<< resultType.getRank();
2171-
2172-
// Verify result vector length based on element type:
2173-
// - 16-bit types (f16, bf16): ds_read_tr16_b64 returns 4 elements
2174-
// - 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): ds_read_tr8_b64
2175-
// returns 8 elements
2176-
int64_t expectedVecLen;
2177-
if (srcElemType.isF16() || srcElemType.isBF16()) {
2178-
expectedVecLen = 4;
2179-
} else if (isa<Float8E4M3FNType>(srcElemType) ||
2180-
isa<Float8E5M2Type>(srcElemType)) {
2181-
expectedVecLen = 8;
2182-
} else {
2183-
return emitOpError("unsupported element type for LDS transpose load: ")
2184-
<< srcElemType;
2185-
}
2186-
2187-
if (resultType.getNumElements() != expectedVecLen) {
2188-
return emitOpError("expected result vector of ")
2189-
<< expectedVecLen << " elements for " << srcElemType
2190-
<< " type, but got " << resultType.getNumElements();
2191-
}
2192-
21932161
// Check hardware support using AmdArchDb
21942162
StringRef arch = rock::getArchValue(*this);
21952163
AmdArchInfo archInfo = rock::lookupArchInfo(arch);
@@ -2373,6 +2341,31 @@ LogicalResult ThreadwiseReadIntoOp::verify() {
23732341
"in register-to-register reads produced by input fusion");
23742342
}
23752343
}
2344+
2345+
// Structural checks for the LDS transpose load fast path.
2346+
if (LDSTransposeConfigAttr cfg = getLdsTransposeConfigAttr()) {
2347+
if (destType.getRank() != 1 || destType.isDynamicDim(0))
2348+
return emitOpError("ldsTransposeConfig requires a rank-1 destination "
2349+
"with a static shape");
2350+
Type destElemType = destType.getElementType();
2351+
bool isFp8 = isa<Float8E4M3FNType, Float8E5M2Type>(destElemType);
2352+
bool is16Bit = destElemType.isF16() || destElemType.isBF16();
2353+
if (!is16Bit && !isFp8)
2354+
return emitOpError("ldsTransposeConfig only supports f16, bf16, "
2355+
"f8E4M3FN, or f8E5M2 destination element types");
2356+
int64_t dDim = cfg.getDDim();
2357+
int64_t kDim = cfg.getKDim();
2358+
bool isQuadRateGeometry =
2359+
(dDim == 16 && kDim == 128) || (dDim == 32 && kDim == 64);
2360+
bool isF16OnlyGeometry =
2361+
(dDim == 16 && kDim == 16) || (dDim == 32 && kDim == 8);
2362+
if (isFp8 && isF16OnlyGeometry)
2363+
return emitOpError("MFMA geometry (")
2364+
<< dDim << "x" << kDim << ") is not supported for FP8/BF8";
2365+
if (is16Bit && isQuadRateGeometry)
2366+
return emitOpError("quad-rate MFMA geometry (")
2367+
<< dDim << "x" << kDim << ") is only valid for FP8/BF8";
2368+
}
23762369
return success();
23772370
}
23782371

mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp

Lines changed: 66 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ using namespace mlir::rock;
5050
namespace mlir::rock::hwtranspose {
5151
namespace {
5252

53-
bool archSupported(StringRef arch) { return arch.contains("gfx950"); }
54-
5553
// Check if element type is supported for LDS transpose load
5654
// - f16, bf16: ds_read_tr16_b64 (4 elements)
5755
// - f8E4M3FN, f8E5M2 (OCP FP8 for gfx950): ds_read_tr8_b64 (8 elements)
@@ -78,15 +76,6 @@ static int64_t getTransposeLoadVectorLength(Type elemType) {
7876
llvm_unreachable("Unsupported element type for LDS transpose load");
7977
}
8078

81-
// Validates MFMA geometry for LDS transpose support.
82-
// Supported combinations:
83-
// Standard: (16,16), (16,32), (32,8), (32,16)
84-
// Scaled FP8: (16,128) quad-rate, (32,64) quad-rate
85-
static bool isValidMfmaGeometry(int64_t dDim, int64_t kDim) {
86-
return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 128)) ||
87-
(dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 64));
88-
}
89-
9079
// Shape of a single MFMA instruction (internal use only).
9180
struct MfmaInstrShape {
9281
int64_t mnMfma;
@@ -140,8 +129,10 @@ static Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB,
140129
dec.nPerWave = nPerWave;
141130
dec.doubleBuffering = doubleBuffering;
142131

143-
// Basic applicability checks
144-
if (!archSupported(arch) || !DirectToLds) {
132+
// Basic applicability checks. Use the arch DB as the single source of truth
133+
// for which architectures support ds_read_tr* (kept consistent with the
134+
// verifier in RockDialect.cpp via AmdArchInfo::hasLdsTransposeLoad).
135+
if (!rock::lookupArchInfo(arch).hasLdsTransposeLoad || !DirectToLds) {
145136
return dec;
146137
}
147138

@@ -165,10 +156,28 @@ static Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB,
165156
return dec;
166157

167158
// Validate MFMA geometry
168-
if (!isValidMfmaGeometry(shape.mnMfma, shape.kMfma)) {
159+
if (!isValidLdsTransposeMfmaGeometry(shape.mnMfma, shape.kMfma)) {
169160
return dec;
170161
}
171162

163+
// Reject geometry/type combinations not handled in getBasePanelOffsets:
164+
// - F16/BF16 path supports: (16,16), (16,32), (32,8), (32,16)
165+
// - FP8/BF8 path supports: (16,32), (32,16), (16,128), (32,64)
166+
// Mismatched pairs would hit llvm_unreachable in getBasePanelOffsets.
167+
// typesCompatible() above already guarantees A and B are either identical
168+
// or both FP8/BF8 variants, so checking elemTypeA is sufficient.
169+
bool isQuadRateGeometry = (shape.mnMfma == 16 && shape.kMfma == 128) ||
170+
(shape.mnMfma == 32 && shape.kMfma == 64);
171+
bool isF16OnlyGeometry = (shape.mnMfma == 16 && shape.kMfma == 16) ||
172+
(shape.mnMfma == 32 && shape.kMfma == 8);
173+
if (isFp8Type(elemTypeA)) {
174+
if (isF16OnlyGeometry)
175+
return dec;
176+
} else {
177+
if (isQuadRateGeometry)
178+
return dec;
179+
}
180+
172181
if (!validatePaneling(shape, operand, mPerBlock, nPerBlock, kPerBlock)) {
173182
return dec;
174183
}
@@ -325,9 +334,9 @@ LDSTransposeDecision decideLDSTransposeForOperands(
325334
}
326335
// else - neither operand usable, enableA/enableB remain false.
327336

328-
// Check if numWaves is supported (1, 2, 3, 4, 8, 16)
329-
// TODO: support 32 waves for WMMA
330-
int64_t numWaves = (mPerBlock * nPerBlock) / (mPerWave * nPerWave);
337+
// Check if numWaves is supported (1, 2, 4, 8, 16).
338+
// TODO: support 32 waves for WMMA.
339+
int64_t numWaves = (mPerBlock / mPerWave) * (nPerBlock / nPerWave);
331340
if (numWaves > 16) {
332341
result.enableA = false;
333342
result.enableB = false;
@@ -346,7 +355,7 @@ LDSTransposeConfigAttr buildTransposeAttrFromParams(
346355
// INVARIANT: MFMA geometry must be valid
347356
assert(mfmaDDim > 0 && mfmaKDim > 0 &&
348357
"MFMA geometry must be set when building transpose attributes");
349-
assert(isValidMfmaGeometry(mfmaDDim, mfmaKDim) &&
358+
assert(isValidLdsTransposeMfmaGeometry(mfmaDDim, mfmaKDim) &&
350359
"Invalid MFMA geometry for LDS transpose - valid: (16,16), (16,32), "
351360
"(16,128), (32,8), (32,16), (32,64)");
352361

@@ -582,14 +591,17 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc,
582591
//===----------------------------------------------------------------------===//
583592
// writePanelVectorsToDestination - Write loaded panel vectors to destination
584593
//
585-
// Extracts individual f16/bf16 elements from loaded panel vectors and writes
586-
// them sequentially to the destination buffer. Each panel vector contains 4
587-
// elements (ds_read_tr16_b64 always returns vector<4xf16>).
594+
// Extracts individual elements from loaded panel vectors and writes them
595+
// sequentially to the destination buffer. Panel vector width depends on the
596+
// element type:
597+
// - f16/bf16: vector<4> (ds_read_tr16_b64)
598+
// - fp8/bf8: vector<8> (ds_read_tr8_b64)
588599
//
589600
// Parameters:
590601
// panelVectors - Array of loaded panel vectors (vector<4> for f16/bf16,
591-
// vector<8> for fp8/bf8) dest - Destination memref (rank-1, scalar
592-
// layout) targetElems - Maximum number of elements to write
602+
// vector<8> for fp8/bf8)
603+
// dest - Destination memref (rank-1, scalar layout)
604+
// targetElems - Maximum number of elements to write
593605
//
594606
// Returns:
595607
// success() if all target elements were written
@@ -655,13 +667,17 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc,
655667
//===----------------------------------------------------------------------===//
656668
// getBasePanelOffsets - Compute per-panel LDS offsets for a given lane ID
657669
//
658-
// Given a wavefront lane ID and a specific MFMA layout (L16x32, L16x16, etc.),
659-
// this function computes the base byte offsets into LDS memory where each
660-
// lane should read its operands from.
670+
// Given a wavefront lane ID and a specific MFMA layout, this function computes
671+
// the base byte offsets into LDS memory where each lane should read its
672+
// operands from. These offsets are derived from AMD's LDS tiling and MFMA
673+
// operand layout conventions, mapping each lane's register to the correct
674+
// element position in LDS.
661675
//
662-
// These offsets are derived from AMD's LDS tiling and MFMA operand layout
663-
// conventions (e.g., 16x16, 16x32 panels). The goal is to map each lane's
664-
// register to the correct element position in LDS.
676+
// Supported (dDim, kDim) combinations per element type:
677+
// F16 / BF16: (16,16), (16,32), (32,8), (32,16) -- ds_read_tr16_b64
678+
// FP8 / BF8: (16,32), (32,16), (16,128), (32,64) -- ds_read_tr8_b64
679+
// Any other (type, geometry) combination triggers llvm_unreachable. Callers
680+
// must validate the (type, geometry) pair upstream (see makeDecision()).
665681
//
666682
// Note: This is an internal helper function. Use computeLDSBaseOffsets()
667683
// instead for better readability.
@@ -809,7 +825,7 @@ static SmallVector<Value> getBasePanelOffsets(PatternRewriter &b, Location loc,
809825
//
810826
// Parameters:
811827
// dDim - MFMA D dimension (M or N, 16 or 32)
812-
// kDim - MFMA K dimension (8, 16, or 32)
828+
// kDim - MFMA K dimension (8, 16, 32, 64, or 128)
813829
// lane - Thread's lane ID within the workgroup
814830
// elemType - Element type (f16, bf16, fp8, or bf8) for selecting lane mapping
815831
//
@@ -841,17 +857,20 @@ static std::pair<Value, Value> computeLDSBaseOffsets(PatternRewriter &b,
841857
// dimensions, and decomposes the wave ID into a 2D grid position.
842858
//
843859
// This version uses a deterministic layout selection based solely on the number
844-
// of physical waves (1, 2, 3, or 4). The goal is to match the wave grid to the
845-
// number of available wave tiles (waveTilesInM, waveTilesInN) while choosing a
846-
// stable and predictable layout.
860+
// of physical waves. The goal is to match the wave grid to the number of
861+
// available wave tiles (waveTilesInM, waveTilesInN) while choosing a stable
862+
// and predictable layout.
847863
//
848864
// Key principles:
849-
// - physicalWaves ∈ {1, 2, 3, 4} (corresponding to 64–256 threads)
865+
// - physicalWaves ∈ {1, 2, 4, 8, 16}. Tuning generates only power-of-2
866+
// wave-tile factors (see computeDPerWave's `factor *= 2` step), so
867+
// numWaves is always a product of two power-of-2 values.
850868
// - Prefer balanced or natural layouts when possible:
851869
// 1 wave → 1×1
852870
// 2 waves → prefer 1×2
853-
// 3 waves → prefer 1×3
854871
// 4 waves → prefer 2×2
872+
// 8 waves → prefer 2×4
873+
// 16 waves → prefer 4×4
855874
// - If a preferred layout does not fit the available tiles, fallback logic
856875
// selects the best possible layout while maintaining determinism.
857876
// - The result defines which spatial tile each wave is responsible for,
@@ -908,21 +927,6 @@ computeWaveGridLayout(PatternRewriter &b, Location loc, Value waveId,
908927
}
909928
break;
910929

911-
case 3:
912-
// Three waves: prefer 1×3, fallback to 3×1 or dimension-based
913-
if (waveTilesInN >= 3) {
914-
wavesInM = 1;
915-
wavesInN = 3;
916-
} else if (waveTilesInM >= 3) {
917-
wavesInM = 3;
918-
wavesInN = 1;
919-
} else {
920-
// Fallback: choose dimension with more tiles (outer loop handles rest)
921-
wavesInM = (waveTilesInN >= waveTilesInM) ? 1 : 3;
922-
wavesInN = (waveTilesInN >= waveTilesInM) ? 3 : 1;
923-
}
924-
break;
925-
926930
case 4:
927931
// Four waves: prefer 2×2 (balanced), then 1×4, 4×1, or fallback
928932
if (waveTilesInM >= 2 && waveTilesInN >= 2) {
@@ -1286,12 +1290,13 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b,
12861290
OperandKind operand =
12871291
config.getIsOperandA() ? OperandKind::A : OperandKind::B;
12881292

1289-
// Compute wave grid layout and decompose wave ID into 2D position
1293+
// Compute wave grid layout and decompose wave ID into 2D position.
12901294
FailureOr<WaveGridLayout> maybeWaveGrid = computeWaveGridLayout(
12911295
b, loc, waveId, mPerWave, nPerWave, mPerBlock, nPerBlock);
1292-
assert(succeeded(maybeWaveGrid) &&
1293-
"If we decided to use transpose load, this must work");
1294-
WaveGridLayout waveGrid = maybeWaveGrid.value();
1296+
if (failed(maybeWaveGrid))
1297+
return op.emitOpError(
1298+
"unsupported wave grid layout for LDS transpose load");
1299+
WaveGridLayout waveGrid = *maybeWaveGrid;
12951300
Value waveM = waveGrid.waveM;
12961301
Value waveN = waveGrid.waveN;
12971302

@@ -1450,8 +1455,13 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b,
14501455
<< expectedLoads << ", got " << panelVectors.size();
14511456
}
14521457

1453-
// Write loaded panel vectors to destination buffer
1454-
// Destination is rank-1 with scalar sequential layout
1458+
// Write loaded panel vectors to destination buffer.
1459+
// Destination must be rank-1 with a static shape; we cannot statically
1460+
// size the writes otherwise.
1461+
if (destType.getRank() != 1 || destType.isDynamicDim(0)) {
1462+
return op.emitOpError(
1463+
"LDS transpose load destination must be rank-1 with a static shape");
1464+
}
14551465
int64_t destCap = destType.getShape()[0];
14561466
int64_t targetElems = std::min<int64_t>(sliceElems, destCap);
14571467

0 commit comments

Comments
 (0)