Skip to content

Commit dc38c7f

Browse files
format changes
Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>
1 parent 2f55a3d commit dc38c7f

2 files changed

Lines changed: 16 additions & 14 deletions

File tree

mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,8 @@ AmdArchInfo mlir::rock::lookupArchInfo(StringRef arch) {
398398
llvm_unreachable(msg.c_str());
399399
}
400400

401-
FailureOr<int64_t> mlir::rock::lookupDeviceGlobalMemorySizeBytes(StringRef arch) {
401+
FailureOr<int64_t>
402+
mlir::rock::lookupDeviceGlobalMemorySizeBytes(StringRef arch) {
402403
#ifdef _WIN32
403404
(void)arch;
404405
return failure();
@@ -422,8 +423,8 @@ FailureOr<int64_t> mlir::rock::lookupDeviceGlobalMemorySizeBytes(StringRef arch)
422423
return getMemoryForDevice(deviceId);
423424

424425
int deviceCount = 0;
425-
if (auto err = hipGetDeviceCount(&deviceCount); err != hipSuccess ||
426-
deviceCount <= 0)
426+
if (auto err = hipGetDeviceCount(&deviceCount);
427+
err != hipSuccess || deviceCount <= 0)
427428
return failure();
428429

429430
int64_t minMatchingBytes = std::numeric_limits<int64_t>::max();

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,7 @@ static FailureOr<int64_t> getTypeSizeInBytes(Type type) {
128128
}
129129

130130
static FailureOr<int64_t> getSplitKVExtraStorageLimitOverrideBytes() {
131-
const char *envVar =
132-
std::getenv("ROCMLIR_ATTENTION_SPLITKV_MAX_EXTRA_BYTES");
131+
const char *envVar = std::getenv("ROCMLIR_ATTENTION_SPLITKV_MAX_EXTRA_BYTES");
133132
if (!envVar || *envVar == '\0')
134133
return failure();
135134

@@ -155,7 +154,8 @@ static int64_t getSplitKVExtraStorageLimitBytes(AttentionOp op) {
155154
return maybeOverride.value();
156155

157156
StringAttr arch = rock::getArchValue(op.getOperation());
158-
auto maybeDeviceBytes = rock::lookupDeviceGlobalMemorySizeBytes(arch.getValue());
157+
auto maybeDeviceBytes =
158+
rock::lookupDeviceGlobalMemorySizeBytes(arch.getValue());
159159
if (failed(maybeDeviceBytes) || maybeDeviceBytes.value() <= 0)
160160
return defaultLimitBytes;
161161

@@ -189,9 +189,9 @@ static LogicalResult verifySplitKVExtraStorage(AttentionOp op,
189189
return success();
190190

191191
auto maybeBaseElems = checkedMul(batchHeads, seqLenQ);
192-
auto maybeBaseOutElems =
193-
succeeded(maybeBaseElems) ? checkedMul(*maybeBaseElems, headDimV)
194-
: FailureOr<int64_t>(failure());
192+
auto maybeBaseOutElems = succeeded(maybeBaseElems)
193+
? checkedMul(*maybeBaseElems, headDimV)
194+
: FailureOr<int64_t>(failure());
195195
if (failed(maybeBaseElems) || failed(maybeBaseOutElems))
196196
return op.emitError("splitKV storage estimate overflowed");
197197

@@ -203,16 +203,17 @@ static LogicalResult verifySplitKVExtraStorage(AttentionOp op,
203203
if (succeeded(maybeExtraOutBytes))
204204
maybeExtraOutBytes = checkedMul(*maybeExtraOutBytes, *maybeOutElemBytes);
205205

206-
auto maybeExtraLseBytes =
207-
succeeded(maybeBaseElems) ? checkedMul(*maybeBaseElems, extraSplitFactor)
206+
auto maybeExtraLseBytes = succeeded(maybeBaseElems)
207+
? checkedMul(*maybeBaseElems, extraSplitFactor)
208208
: FailureOr<int64_t>(failure());
209209
if (succeeded(maybeExtraLseBytes))
210210
maybeExtraLseBytes = checkedMul(*maybeExtraLseBytes, *maybeLseElemBytes);
211211

212212
if (failed(maybeExtraOutBytes) || failed(maybeExtraLseBytes))
213213
return op.emitError("splitKV storage estimate overflowed");
214214

215-
auto maybeTotalExtraBytes = checkedAdd(*maybeExtraOutBytes, *maybeExtraLseBytes);
215+
auto maybeTotalExtraBytes =
216+
checkedAdd(*maybeExtraOutBytes, *maybeExtraLseBytes);
216217
if (failed(maybeTotalExtraBytes))
217218
return op.emitError("splitKV storage estimate overflowed");
218219

@@ -3335,8 +3336,8 @@ static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,
33353336

33363337
if (isa<AttentionOp>(op)) {
33373338
AttentionOp attentionOp = cast<AttentionOp>(op);
3338-
if (failed(verifySplitKVExtraStorage(attentionOp, qBatchDim, queryM,
3339-
valueN)))
3339+
if (failed(
3340+
verifySplitKVExtraStorage(attentionOp, qBatchDim, queryM, valueN)))
33403341
return failure();
33413342
}
33423343

0 commit comments

Comments
 (0)