Skip to content

Commit 2276637

Browse files
committed
Enforce rocmlirMIGraphXAttentionCreate input contract in release builds
The previous round of C API hardening guarded the entry-point arguments with assert(), which compiles to a no-op under NDEBUG. Release-build callers that violated the contract (NULL queries / keys / values, NULL preSoftmaxElemWiseInputs with a non-zero count, negative splitKV / slidingWindowSize / numPreSoftmaxInputs, NULL resultType) would still fall through to the same confusing failure modes the assertions were meant to prevent: NULL deref while iterating the inputs array, "no Q operand" diagnostics on a half-built op, or splitKV silently dropped. Replace the asserts with `if (...) return reject(msg)` checks that fire uniformly in debug and release builds, write a "rocmlirMIGraphXAttentionCreate: <reason>" line to stderr (matching the existing `mlirMIGraphXAddBackendPipeline` diagnostic style in this file), and return a null MlirOperation. Callers can detect the failure with mlirOperationIsNull. The order of checks now also catches a negative numPreSoftmaxInputs before it is used as a loop bound, fixing a subtle order-of-evaluation hazard in the previous assert chain. Document the new contract in mlir-c/Dialect/MIGraphX.h: list every condition that returns null and clarify that all *other* invariants (operand element types, shape compatibility, feature/operand consistency) are still enforced by the AttentionOp verifier rather than by the C API. This matches what the implementation actually does and lets callers distinguish "I gave it garbage" (null return, detectable here) from "the verifier rejected my IR" (verified op returned but parsing/verification fails later). Existing CAPI tests (mlir/test/CAPI/mixr_attention.c) continue to pass unchanged because they always provide valid inputs. Made-with: Cursor
1 parent 2bc70e5 commit 2276637

2 files changed

Lines changed: 40 additions & 14 deletions

File tree

mlir/include/mlir-c/Dialect/MIGraphX.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,17 @@ mlirMIGraphXAddBackendPipeline(MlirPassManager pm,
105105
/// \p prefixOffset is required when prefix_offset is set; pass null to omit.
106106
/// \p splitKV is the number of KV splits (0 or 1 = omit attribute).
107107
/// \p slidingWindowSize is the window size (0 = omit attribute).
108+
///
109+
/// Contract violations are rejected with a stderr diagnostic and a null
110+
/// MlirOperation return (check via mlirOperationIsNull). The same contract
111+
/// is enforced in both debug and release builds. Specifically the function
112+
/// returns a null op (and writes a "rocmlirMIGraphXAttentionCreate: ..."
113+
/// line to stderr) if any of \p queries, \p keys, \p values is null, if
114+
/// \p numPreSoftmaxInputs is negative or \p preSoftmaxElemWiseInputs is
115+
/// NULL when the count is positive, if \p splitKV or \p slidingWindowSize
116+
/// is negative, or if \p resultType is null. All other invariants (operand
117+
/// element types, shape compatibility, feature/operand consistency, etc.)
118+
/// are still left to the AttentionOp verifier.
108119
MLIR_CAPI_EXPORTED MlirOperation rocmlirMIGraphXAttentionCreate(
109120
MlirLocation location, MlirValue queries, MlirValue keys, MlirValue values,
110121
intptr_t numPreSoftmaxInputs, const MlirValue *preSoftmaxElemWiseInputs,

mlir/lib/CAPI/Dialect/MIGraphX.cpp

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,35 @@ MLIR_CAPI_EXPORTED MlirOperation rocmlirMIGraphXAttentionCreate(
170170
MlirType resultType, MlirType lseType, MlirType softmaxType,
171171
MlirRegion preSoftmaxBody, uint32_t features, MlirValue currentSeqLen,
172172
MlirValue prefixOffset, int32_t splitKV, int32_t slidingWindowSize) {
173-
// Reject contract violations up front. The op verifier would catch most
174-
// of these later, but the failure modes are confusing (NULL deref on
175-
// the inputs array, "no Q operand" diagnostics on the parsed op) and
176-
// splitKV < 0 used to be silently dropped.
177-
assert(!mlirValueIsNull(queries) && "queries operand is required");
178-
assert(!mlirValueIsNull(keys) && "keys operand is required");
179-
assert(!mlirValueIsNull(values) && "values operand is required");
180-
assert((numPreSoftmaxInputs == 0 || preSoftmaxElemWiseInputs != nullptr) &&
181-
"preSoftmaxElemWiseInputs array must be non-NULL when count > 0");
182-
assert(numPreSoftmaxInputs >= 0 &&
183-
"numPreSoftmaxInputs must be non-negative");
184-
assert(splitKV >= 0 && "splitKV must be non-negative (0 or 1 = omit)");
185-
assert(slidingWindowSize >= 0 && "slidingWindowSize must be non-negative");
186-
assert(!mlirTypeIsNull(resultType) && "resultType is required");
173+
// Reject contract violations up front and uniformly across debug and
174+
// release builds. The op verifier would catch most of these later, but
175+
// the failure modes are confusing in NDEBUG builds (NULL deref on the
176+
// inputs array, "no Q operand" diagnostics on a half-built op,
177+
// splitKV < 0 silently dropped) and the previous assert-only checks
178+
// compiled out in release. Returning a null MlirOperation lets callers
179+
// detect failure with mlirOperationIsNull and matches the conventions
180+
// documented in the header.
181+
auto reject = [](const char *msg) -> MlirOperation {
182+
llvm::errs() << "rocmlirMIGraphXAttentionCreate: " << msg << "\n";
183+
return MlirOperation{nullptr};
184+
};
185+
if (mlirValueIsNull(queries))
186+
return reject("queries operand is required");
187+
if (mlirValueIsNull(keys))
188+
return reject("keys operand is required");
189+
if (mlirValueIsNull(values))
190+
return reject("values operand is required");
191+
if (numPreSoftmaxInputs < 0)
192+
return reject("numPreSoftmaxInputs must be non-negative");
193+
if (numPreSoftmaxInputs > 0 && preSoftmaxElemWiseInputs == nullptr)
194+
return reject(
195+
"preSoftmaxElemWiseInputs array must be non-NULL when count > 0");
196+
if (splitKV < 0)
197+
return reject("splitKV must be non-negative (0 or 1 = omit)");
198+
if (slidingWindowSize < 0)
199+
return reject("slidingWindowSize must be non-negative");
200+
if (mlirTypeIsNull(resultType))
201+
return reject("resultType is required");
187202

188203
MlirContext ctx = mlirLocationGetContext(location);
189204
MlirOperationState state = mlirOperationStateGet(

0 commit comments

Comments
 (0)