Skip to content

Commit d5ca8d5

Browse files
address review comments: dyn_cast, StringMap, safeGlobalMemBytes
- Replace isa<> + cast<> pairs with dyn_cast<> in verifyCommonAttnGemmParameters - Replace std::unordered_map<std::string> cache with llvm::StringMap - Extract safeGlobalMemBytes helper to remove duplicated overflow-guard logic in lookupDeviceGlobalMemorySizeBytes Made-with: Cursor
1 parent 3af1e0e commit d5ca8d5

2 files changed

Lines changed: 21 additions & 23 deletions

File tree

mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,15 @@ AmdArchInfo mlir::rock::lookupArchInfo(StringRef arch) {
398398
llvm_unreachable(msg.c_str());
399399
}
400400

401+
static FailureOr<int64_t> safeGlobalMemBytes(const hipDeviceProp_t &prop) {
402+
constexpr uint64_t int64Max =
403+
static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
404+
uint64_t bytes = static_cast<uint64_t>(prop.totalGlobalMem);
405+
if (bytes > int64Max)
406+
return failure();
407+
return static_cast<int64_t>(bytes);
408+
}
409+
401410
FailureOr<int64_t>
402411
mlir::rock::lookupDeviceGlobalMemorySizeBytes(StringRef arch) {
403412
#ifdef _WIN32
@@ -410,13 +419,7 @@ mlir::rock::lookupDeviceGlobalMemorySizeBytes(StringRef arch) {
410419
hipDeviceProp_t prop;
411420
if (auto err = hipGetDeviceProperties(&prop, id); err != hipSuccess)
412421
return failure();
413-
414-
constexpr uint64_t int64Max =
415-
static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
416-
uint64_t bytes = static_cast<uint64_t>(prop.totalGlobalMem);
417-
if (bytes > int64Max)
418-
return failure();
419-
return static_cast<int64_t>(bytes);
422+
return safeGlobalMemBytes(prop);
420423
};
421424

422425
if (chip == "native")
@@ -439,14 +442,11 @@ mlir::rock::lookupDeviceGlobalMemorySizeBytes(StringRef arch) {
439442
if (deviceChip != chip)
440443
continue;
441444

442-
constexpr uint64_t int64Max =
443-
static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
444-
uint64_t bytes = static_cast<uint64_t>(prop.totalGlobalMem);
445-
if (bytes > int64Max)
445+
auto maybeBytes = safeGlobalMemBytes(prop);
446+
if (failed(maybeBytes))
446447
continue;
447448

448-
int64_t bytes64 = static_cast<int64_t>(bytes);
449-
minMatchingBytes = std::min(minMatchingBytes, bytes64);
449+
minMatchingBytes = std::min(minMatchingBytes, *maybeBytes);
450450
foundMatch = true;
451451
}
452452

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include "llvm/ADT/SmallSet.h"
5151
#include "llvm/ADT/SmallVector.h"
5252
#include "llvm/ADT/StringExtras.h"
53+
#include "llvm/ADT/StringMap.h"
5354
#include "llvm/ADT/StringRef.h"
5455
#include "llvm/ADT/TypeSwitch.h"
5556
#include "llvm/Support/Debug.h"
@@ -65,8 +66,6 @@
6566
#include <iterator>
6667
#include <limits>
6768
#include <mutex>
68-
#include <string>
69-
#include <unordered_map>
7069

7170
using namespace mlir;
7271
using namespace mlir::rock;
@@ -157,15 +156,14 @@ static int64_t getSplitKVExtraStorageLimitBytes(AttentionOp op) {
157156
return maybeOverride.value();
158157

159158
StringAttr arch = rock::getArchValue(op.getOperation());
160-
std::string archKey = arch.getValue().str();
161159

162160
// Cache the computed limit per arch to avoid repeating HIP queries during
163161
// verifier runs over many attention ops.
164162
static std::mutex cacheMutex;
165-
static std::unordered_map<std::string, int64_t> cachedLimits;
163+
static llvm::StringMap<int64_t> cachedLimits;
166164
std::lock_guard<std::mutex> lock(cacheMutex);
167165

168-
auto it = cachedLimits.find(archKey);
166+
auto it = cachedLimits.find(arch.getValue());
169167
if (it != cachedLimits.end())
170168
return it->second;
171169

@@ -178,7 +176,8 @@ static int64_t getSplitKVExtraStorageLimitBytes(AttentionOp op) {
178176
std::clamp(dynamicLimit, minDynamicLimitBytes, maxDynamicLimitBytes);
179177
}
180178

181-
auto [insertedIt, inserted] = cachedLimits.emplace(archKey, limitBytes);
179+
auto [insertedIt, inserted] =
180+
cachedLimits.insert({arch.getValue(), limitBytes});
182181
(void)inserted;
183182
return insertedIt->second;
184183
}
@@ -3331,8 +3330,8 @@ static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,
33313330
ShapedType oType = cast<ShapedType>(op.getOutType());
33323331
int64_t oBatchDim = oType.getShape().size() == 3 ? oType.getShape()[0] : 1;
33333332
int64_t oBatchDimOrig = oBatchDim;
3334-
if (isa<AttentionOp>(op)) {
3335-
int64_t splitKV = cast<AttentionOp>(op).getSplitKV();
3333+
if (auto attentionOp = dyn_cast<AttentionOp>(op)) {
3334+
int64_t splitKV = attentionOp.getSplitKV();
33363335
if (oBatchDim % splitKV != 0)
33373336
return op.emitError("Batch size must be divisible by splitKV");
33383337

@@ -3357,8 +3356,7 @@ static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,
33573356
return op.emitError("Head dimensions do not match (V and Output)");
33583357
}
33593358

3360-
if (isa<AttentionOp>(op)) {
3361-
AttentionOp attentionOp = cast<AttentionOp>(op);
3359+
if (auto attentionOp = dyn_cast<AttentionOp>(op)) {
33623360
if (failed(
33633361
verifySplitKVExtraStorage(attentionOp, qBatchDim, queryM, valueN)))
33643362
return failure();

0 commit comments

Comments
 (0)