Skip to content

Commit 7323bea

Browse files
Harden rocmlirMIGraphXAttentionCreate input contract
The C API helper used to silently accept null Q/K/V (producing "missing operand" diagnostics on the parsed op), null preSoftmaxElemWiseInputs with count > 0 (NULL deref), and negative splitKV / slidingWindowSize (silently dropped instead of attached). Add asserts at the top of the helper for the documented contract - the header doc already says inputs are required and that splitKV is "0 or 1 = omit", so this just enforces what's already promised. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 3a70f7c commit 7323bea

1 file changed

Lines changed: 15 additions & 0 deletions

File tree

mlir/lib/CAPI/Dialect/MIGraphX.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,21 @@ MLIR_CAPI_EXPORTED MlirOperation rocmlirMIGraphXAttentionCreate(
170170
MlirType resultType, MlirType lseType, MlirType softmaxType,
171171
MlirRegion preSoftmaxBody, uint32_t features, MlirValue currentSeqLen,
172172
MlirValue prefixOffset, int32_t splitKV, int32_t slidingWindowSize) {
173+
// Reject contract violations up front. The op verifier would catch most
174+
// of these later, but the failure modes are confusing (NULL deref on
175+
// the inputs array, "no Q operand" diagnostics on the parsed op) and
176+
// splitKV < 0 used to be silently dropped.
177+
assert(!mlirValueIsNull(queries) && "queries operand is required");
178+
assert(!mlirValueIsNull(keys) && "keys operand is required");
179+
assert(!mlirValueIsNull(values) && "values operand is required");
180+
assert((numPreSoftmaxInputs == 0 || preSoftmaxElemWiseInputs != nullptr) &&
181+
"preSoftmaxElemWiseInputs array must be non-NULL when count > 0");
182+
assert(numPreSoftmaxInputs >= 0 &&
183+
"numPreSoftmaxInputs must be non-negative");
184+
assert(splitKV >= 0 && "splitKV must be non-negative (0 or 1 = omit)");
185+
assert(slidingWindowSize >= 0 && "slidingWindowSize must be non-negative");
186+
assert(!mlirTypeIsNull(resultType) && "resultType is required");
187+
173188
MlirContext ctx = mlirLocationGetContext(location);
174189
MlirOperationState state = mlirOperationStateGet(
175190
mlirStringRefCreateFromCString("migraphx.attention"), location);

0 commit comments

Comments
 (0)