
Introduce migraphx.attention op in MIGraphX Dialect #2316

Draft
umangyadav wants to merge 70 commits into develop from migraphx_attention

Conversation


umangyadav (Member) commented on Mar 27, 2026

Motivation

Currently, MIGraphX passes a series of decomposed ops for attention. Then, inside TosaToRock, rocMLIR pattern-matches to detect attention and determine which attention variant it is.

This pattern matching is becoming increasingly difficult to maintain on the rocMLIR side; it is much better for the MIGraphX graph compiler to do it.

Therefore, this PR introduces a migraphx.attention op that takes a "features" attribute describing which attention variant it represents.

Technical Details

  • Adds a migraphx.attention op with semantics similar to the rock.attention op.
  • For host compilation, the migraphx.attention op is decomposed into ops that can then be lowered to linalg.
  • For the GPU, it lowers directly to the rock.attention op.
  • Adds a migraphx.yield op for the preSoftmaxBody region.
  • Adds the missing "Elementwise" trait to MIGraphX elementwise ops.
  • Adds a utility method in the C API to construct the attention op (subject to change).

Test Plan

  • Adds some initial E2E tests for this.
  • Convert all existing E2E tests into equivalent migraphx.attention form to make sure functionality is preserved with accuracy.
  • Integrate with MIGraphX.

Merge plan:

I plan to break this large PR into several smaller ones, but I am keeping this large draft open for feedback on the overall structure.

  • Add the elementwise trait
  • Add the migraphx.yield operator
  • Add the migraphx.attention op
  • Lower to the CPU and GPU paths, with initial E2E tests
  • Convert all existing E2E tests to use the new migraphx.attention op
  • Add C API utilities and tests

Note to MIGraphX folks

Look at the files in mlir/test/fusion/pr-e2e/migraphx-attention/ for examples of attention kernels.

E.g., an attention kernel with sliding window + kvcache + causal masking looks like the following:

module {
  func.func private @mlir_attention(%arg0: !migraphx.shaped<1x2x1x2xf16, 4x2x2x1>,
                                     %arg1: !migraphx.shaped<1x2x2x8xf16, 32x16x8x1>,
                                     %arg2: !migraphx.shaped<1x2x8x2xf16, 32x16x2x1>,
                                     %arg3: !migraphx.shaped<1x2xi32, 2x1>)
                                     -> !migraphx.shaped<1x2x1x2xf16, 4x2x2x1> {
    %0 = migraphx.attention %arg0, %arg1, %arg2
      current_seq_len(%arg3 : !migraphx.shaped<1x2xi32, 2x1>) {
      } features = "kvcache|causal|sliding_window" slidingWindowSize = 4
      : <1x2x1x2xf16, 4x2x2x1>, <1x2x2x8xf16, 32x16x8x1>, <1x2x8x2xf16, 32x16x2x1>
      -> <1x2x1x2xf16, 4x2x2x1>
    return %0 : !migraphx.shaped<1x2x1x2xf16, 4x2x2x1>
  }
}
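For orientation, the combination of these three features follows the standard sliding-window causal masking over a KV cache. The formula below is background (an assumption about the intended semantics, not text from this PR): a query at absolute position p may attend to a key at absolute position j only if

\[
j \le p \quad\text{and}\quad p - W < j, \qquad W = \texttt{slidingWindowSize},
\]

while currentSeqLen bounds which cache slots hold valid keys at all.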

SplitKV = 2 with LSE and a "preSoftmaxBody":

  func.func private @mlir_attention(%arg0: !migraphx.shaped<1x4x64x64xf16, 16384x4096x64x1>,
                                     %arg1: !migraphx.shaped<1x4x64x128xf16, 32768x8192x128x1>,
                                     %arg2: !migraphx.shaped<1x4x128x64xf16, 32768x8192x64x1>,
                                     %arg3: !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>)
                                     -> (!migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>,
                                         !migraphx.shaped<1x4x2x64xf32, 512x128x64x1>) {
    %0, %1 = migraphx.attention %arg0, %arg1, %arg2
      pre_softmax_inputs(%arg3 : !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>) {
      ^bb0(%qk: !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>,
           %s: !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>):
        %scaled = migraphx.mul %qk, %s
          : <1x4x2x64x64xf16, 32768x8192x4096x64x1>, <1x4x2x64x64xf16, 32768x8192x4096x64x1>
          -> <1x4x2x64x64xf16, 32768x8192x4096x64x1>
        migraphx.yield %scaled : !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>
      } softmax_type = f32 features = splitkv splitKV = 2
      : <1x4x64x64xf16, 16384x4096x64x1>, <1x4x64x128xf16, 32768x8192x128x1>, <1x4x128x64xf16, 32768x8192x64x1>
      -> <1x4x2x64x64xf16, 32768x8192x4096x64x1>, !migraphx.shaped<1x4x2x64xf32, 512x128x64x1>
    return %0, %1 : !migraphx.shaped<1x4x2x64x64xf16, 32768x8192x4096x64x1>, !migraphx.shaped<1x4x2x64xf32, 512x128x64x1>
  }
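For context on the second result: with splitKV the softmax runs independently per key chunk, so the op returns a per-chunk partial output together with a per-chunk log-sum-exp (LSE). The exact full-attention output can be reconstructed downstream with the standard flash-decoding merge (stated here as background, not as MIGraphX's implementation):

\[
L = \log \sum_c e^{L_c}, \qquad O = \sum_c e^{\,L_c - L}\, O_c ,
\]

where O_c and L_c are chunk c's partial output and LSE. A fully masked chunk yields L_c = -infinity and therefore zero weight in the merge, which is why the split-KV E2E tests described below constrain currentSeqLen so that every chunk keeps some valid keys when per-chunk partials are compared directly.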

Copilot AI left a comment


Pull request overview

This PR introduces a first-class migraphx.attention op to represent attention variants (via a composable feature bitmask) in the MIGraphX dialect, reducing downstream pattern-matching complexity and enabling clearer lowering paths for host (decompose) vs GPU (lower to rock.attention).

Changes:

  • Add migraphx.attention + migraphx.yield, along with attention feature flags and verifier coverage (valid/invalid tests).
  • Add host-side decomposition of migraphx.attention inside MIGraphXTransform and a new GPU lowering pass migraphx-attention-to-rock.
  • Integrate the new pass into pipelines and add initial E2E + conversion + C API construction tests.

Reviewed changes

Copilot reviewed 38 out of 38 changed files in this pull request and generated 2 comments.

Summary per file:
mlir/test/rocmlir-driver/pipelines.mlir Updates expected high-level pipelines to include migraphx-attention-to-rock.
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-basic.mlir Adds E2E attention test case (basic).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-softmax-f32.mlir Adds E2E attention test case (softmaxType f32).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-lse.mlir Adds E2E attention test case (LSE output).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-scale.mlir Adds E2E attention test case (pre-softmax scale/bias region via migraphx.yield).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-causal.mlir Adds E2E attention test case (causal).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-causal-scale.mlir Adds E2E attention test case (causal + pre-softmax body).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache.mlir Adds E2E attention test case (kvcache).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-causal.mlir Adds E2E attention test case (kvcache + causal).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-causal-prefix.mlir Adds E2E attention test case (kvcache + causal + prefix offset).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-causal-sliding-window.mlir Adds E2E attention test case (kvcache + causal + sliding window).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-kvcache-scale.mlir Adds E2E attention test case (kvcache + pre-softmax body).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-splitkv.mlir Adds E2E attention test case (splitKV).
mlir/test/fusion/pr-e2e/migraphx-attention/mixr-attention-gqa.mlir Adds E2E attention test case (GQA shapes).
mlir/test/Dialect/MIGraphX/ops.mlir Adds parsing/printing coverage for new ops/attrs and attention variants.
mlir/test/Dialect/MIGraphX/invalid.mlir Adds verifier negative tests for attention feature/shape/operand constraints.
mlir/test/Conversion/MIGraphXToTosa/migraphx-to-tosa-preserves-rock-attention.mlir Ensures MIGraphXToTosa doesn’t rewrite rock.attention regions produced earlier.
mlir/test/Conversion/MIGraphXAttentionToRock/attention-to-rock.mlir Adds conversion tests for --migraphx-attention-to-rock.
mlir/test/Conversion/MIGraphXAttentionDecompose/attention-decompose.mlir Adds host decomposition tests for attention variants and feature combinations.
mlir/test/CAPI/mixr_attention.c Adds C API test that constructs migraphx.attention ops.
mlir/test/CAPI/CMakeLists.txt Builds/links new mlir-mixr-attention-test.
mlir/lib/Dialect/MIGraphX/Transforms/MIGraphXTransform.cpp Implements host-side migraphx.attention decomposition (non-kernel funcs).
mlir/lib/Dialect/MIGraphX/Pipeline/Pipeline.cpp Inserts MIGraphXAttentionToRock into the high-level pipeline before TOSA/Linalg lowering.
mlir/lib/Dialect/MIGraphX/Pipeline/CMakeLists.txt Links in new MLIRMIGraphXAttentionToRock library.
mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp Adds AttentionOp verifier + feature dependency checks.
mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosaPass.cpp Marks Rock dialect / rock.attention recursively legal to preserve nested ops.
mlir/lib/Conversion/MIGraphXAttentionToRock/MIGraphXAttentionToRock.cpp New lowering pass from migraphx.attention to rock.attention for kernel funcs.
mlir/lib/Conversion/MIGraphXAttentionToRock/CMakeLists.txt Adds conversion library target for MIGraphXAttentionToRock.
mlir/lib/Conversion/CMakeLists.txt Adds new conversion subdirectory.
mlir/lib/CAPI/Dialect/MIGraphX.cpp Adds C API builder helper rocmlirMIGraphXAttentionCreate.
mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphXTypes.td Adds AttentionFeatures bit-flag enum attr.
mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td Adds migraphx.yield, migraphx.attention, and applies Elementwise trait to elementwise ops.
mlir/include/mlir/Conversion/RocMLIRPasses.td Declares migraphx-attention-to-rock pass.
mlir/include/mlir/Conversion/RocMLIRPasses.h Exposes new conversion header.
mlir/include/mlir/Conversion/MIGraphXAttentionToRock/MIGraphXAttentionToRock.h Adds public pass declaration header.
mlir/include/mlir-c/Dialect/MIGraphX.h Bumps C API version and declares attention builder + feature bit macros.


umangyadav self-assigned this on Mar 27, 2026

dhernandez0 (Contributor) commented on Mar 31, 2026

Some general comments/questions:

  • Why do we need rocmlirMIGraphXAttentionCreate? Can't they just send IR the same way they do for conv/gemm?
  • Why limit preSoftmaxBody to elementwise ops only? I think they might currently have reshapes inside the pre-softmax region, so this might be too strict.
  • IMO the attention features enum adds complexity (what happens if they send features without kvcache but enable currentSeqLen?). Is there a reason we need it? We could detect the features from the optional params they send (currentSeqLen, prefixOffset, splitKV, slidingWindowSize). The only one without an associated param is causal, which we could handle with a simple boolean causal param on rock.attention.
  • AttentionDecompose is a lot of code just for CPU lowering, which is not really relevant for us. I wonder if there's an easier way to do this, such as using an existing dialect (torch/aten), converting migraphx.attention -> aten.attention, and reusing their lowering to linalg.

umangyadav and others added 30 commits May 3, 2026 13:06
Two coupled fixes for the host attention decompose contract:

- Convert qk to softmaxType BEFORE applying masks. Mask -inf injection
  goes through getNegInfAttr, which asserts on integer types, so i8
  Q/K with a mask (e.g. kvcache) used to crash. The verifier already
  guarantees softmaxType is set whenever the value entering softmax
  doesn't match V's element type, so the converted type is always one
  of the float types in AttentionVTypes - safe for masks.

- Add a verifier rule requiring result.elementType == V.elementType.
  The host decompose's second GEMM and rock.attention's gemm1 both
  produce in V's type; allowing a different result type let bad IR
  through that the lowering couldn't honor without a final convert.
  Callers wanting a different output dtype must convert downstream.

Tests: new negative test for the verifier rule, new E2E regression
for i8 Q/K + kvcache (the formerly-crashing case), and CHECK-line
updates in attention-decompose.mlir for the reordered convert.

Co-authored-by: Cursor <cursoragent@cursor.com>
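For illustration, IR of the following shape is what the new verifier rule rejects. This is only a sketch: the shapes and attributes are copied from the sliding-window example in the PR description, and the actual negative test and diagnostic wording live in invalid.mlir.

func.func private @attention_result_type_mismatch(%arg0: !migraphx.shaped<1x2x1x2xf16, 4x2x2x1>,
                                                  %arg1: !migraphx.shaped<1x2x2x8xf16, 32x16x8x1>,
                                                  %arg2: !migraphx.shaped<1x2x8x2xf16, 32x16x2x1>,
                                                  %arg3: !migraphx.shaped<1x2xi32, 2x1>)
                                                  -> !migraphx.shaped<1x2x1x2xf32, 4x2x2x1> {
  // Rejected: result element type (f32) differs from V's element type (f16);
  // a caller wanting f32 output must convert downstream instead.
  %0 = migraphx.attention %arg0, %arg1, %arg2
    current_seq_len(%arg3 : !migraphx.shaped<1x2xi32, 2x1>) {
    } features = "kvcache|causal|sliding_window" slidingWindowSize = 4
    : <1x2x1x2xf16, 4x2x2x1>, <1x2x2x8xf16, 32x16x8x1>, <1x2x8x2xf16, 32x16x2x1>
    -> <1x2x1x2xf32, 4x2x2x1>
  return %0 : !migraphx.shaped<1x2x1x2xf32, 4x2x2x1>
}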
The previous verifier allowed Q's leading dims to be divisible by K/V's
at any leading position, but both lowering paths only implement GQA on
the heads axis (dim 1) of rank-4 tensors:

- Host decompose's broadcastForGQA reads shape[1] for numHeadsKV.
- GPU lowering's getNumHeads returns getDimSize(1) only for rank == 4.

So a Q [4, 2, ...] / K [2, 2, ...] pair (batch differs by a factor of
2) used to verify but mis-lower silently: nothing broadcasts K/V across
the batch axis, and the first GEMM ends up with a shape mismatch.

Encode the contract directly: K and V leading dims must match, then
Q's leading dims must equal K's exactly except possibly at dim 1 of a
rank-4 op, where Q may be an integer multiple of K. Producers wanting
to pack extra leading dims should collapse them into the batch dim
before constructing the op.

Adds attention_gqa_batch_dim_divisible_rejected as a regression test;
all existing GQA tests (heads-only divisibility, rank-3 rejection,
rank-5 rejection, K-vs-V mismatch) keep their original diagnostics.

Co-authored-by: Cursor <cursoragent@cursor.com>
- Bump Copyright (c) 2025 -> 2026 in MIGraphXAttentionToRock.{h,cpp};
  both files were added in 2026 (commit 20f12ba, 2026-03-25) but
  the headers carried a stale 2025 year.
- Drop the lone "TODO: Add FP8 as supported data type" in MIGraphX.td;
  the AttentionQKTypes / AttentionVTypes / AttentionLseTypes lists
  themselves document what is supported, matching the pattern already
  used for the V and LSE type lists. Reword the preceding comment to
  note that FP8 is not currently supported.

Co-authored-by: Cursor <cursoragent@cursor.com>
The preSoftmaxBody verifier groups three ops as "skip the integer
operand check": dequantizelinear, convert, and where. The first two
take integer input by design, but the comment for where says only
operand 0 (the i8 boolean mask) may be integer; operands 1 and 2 are
the selected branches and feed select-style scalar arith on float
values, so they must still be float. The previous code skipped checks
for ALL where operands, so a where with i32 branches in the body
silently verified and would later crash the scalar lowering with an
arith.select-on-i32 mismatch.

Split the skip into two cases - skip everything for dequantize/convert,
skip only operand 0 for where - and add a regression test
attention_pre_softmax_where_int_branches that exercises a where with
i32 branches and expects the operand-1 diagnostic.

Co-authored-by: Cursor <cursoragent@cursor.com>
GridwiseGemmToBlockwise has eight invertTransforms call sites; seven
return a clean op.emitError when the inversion fails, but the splitKV
preSoftmax path (added later) was the one outlier that asserted on
success and called .value() unconditionally. Match the surrounding
style so a failure surfaces as a diagnostic in release builds instead
of crashing or producing garbage IR.

Co-authored-by: Cursor <cursoragent@cursor.com>
Host MIGraphXAttentionDecompose has explicit splitKV+kvcache coverage
(decompose_splitkv_kvcache_global_indices), but the GPU side had none:
attention_splitkv and attention_splitkv_presoftmax in
attention-to-rock.mlir don't use kvcache, and the splitkv* PR E2E
tests don't supply currentSeqLen. Add:

- attention_splitkv_kvcache lit test: structural check that the
  rock.attention picks up both currentSeqLen and splitKV = 2.
- mixr-attention-splitkv-kvcache.mlir clone-verifier E2E test for
  the KV-cache-decode shape (seqQ=1, seqK=128, splitKV=2). The
  result/LSE are split-space (5D/4D), so the verifier compares
  per-chunk partials directly; currentSeqLen is constrained to
  [80, 128) so every chunk has at least 16 valid keys (a fully
  masked chunk would produce -inf LSE / undefined partial output and
  the per-chunk comparison would fail even though the merged final
  result would be correct).

Co-authored-by: Cursor <cursoragent@cursor.com>
The pass is already added to a func-nested PM in
mlir/lib/Dialect/MIGraphX/Pipeline/Pipeline.cpp, and its
runOnOperation walked all func::FuncOps under whatever root it got
handed. Anchoring the pass declaration to func::FuncOp makes the
contract explicit (matches MIGraphXToTosaPass) and lets the body
collapse to a direct getOperation() + rock.kernel filter.

Co-authored-by: Cursor <cursoragent@cursor.com>
The C API helper used to silently accept null Q/K/V (producing
"missing operand" diagnostics on the parsed op), null
preSoftmaxElemWiseInputs with count > 0 (NULL deref), and negative
splitKV / slidingWindowSize (silently dropped instead of attached).
Add asserts at the top of the helper for the documented contract -
the header doc already says inputs are required and that splitKV is
"0 or 1 = omit", so this just enforces what's already promised.

Co-authored-by: Cursor <cursoragent@cursor.com>
Two attention E2E tests had RMS/relDiff thresholds far looser than the
actual host-vs-GPU divergence they observe. Loose thresholds hide
regressions: a real numerical drift would still pass the test if it
stayed under the wide bound. Both originated from copy-paste defaults
and were never measured against the actual diff their inputs produce.

mixr-attention-kvcache-causal-sliding-window-clamp.mlir
  Was RMS = 0.02, relDiff = 0.1 (copied from the parent
  kvcache-causal-sliding-window test). The clamp test pins
  currentSeqLen to [0, 2] with windowSize = 4 and works on a tiny
  4-element output with at most 2 valid keys per row, so host and GPU
  agree bit-exactly with the fixed seed (passes at 1e-9). Tighten to
  the standard 0.0005 / 0.0005 thresholds used by the rest of the
  attention E2E suite, leaving a small margin for harmless
  deterministic codegen changes (e.g. a different mfma intrinsic) but
  closing the 40-200x window the loose bounds left open. Add a comment
  recording the empirical bit-exactness and the rationale for the
  tightened (rather than 1e-9) bound.

mixr-attention-mixed-q-bf16-v-f32.mlir
  Was RMS = 0.01, relDiff = 0.05; empirically the test produces RMS
  ~3.3e-3 and worst-case relDiff ~3.9e-2 on the fixed seed. The
  relDiff floor is set by bf16's ~3 decimal mantissa digits in the
  first GEMM and by migraphx.dot's bf16/operand-type accumulator in
  the host's second GEMM (called out in commit dc0ddb3's
  message). Tighten RMS from 0.01 to 0.005 (1.5x measured); keep
  relDiff = 0.05 (1.3x measured) since worst-case relDiff is what
  varies most with input distribution. Add a comment documenting both
  bf16 error sources so future readers don't try to tighten further
  without widening the host accumulator first.

Both changes verified by ninja check-rocmlir; all 35 attention E2E
tests pass.

Made-with: Cursor
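As background for these bounds (general floating-point arithmetic, not taken from the commit itself): bf16 carries 8 significand bits, so a single rounding step has a relative error of roughly

\[
u_{\mathrm{bf16}} = 2^{-8} \approx 3.9 \times 10^{-3},
\]

and that error compounds across the first GEMM's reduction and the softmax chain, which is consistent with the measured RMS of ~3.3e-3 and worst-case relDiff of ~3.9e-2 quoted above.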
MIGraphXToTosaPass references three dialects in its conversion target
that were not previously declared as dependent dialects:
  - rock (addLegalDialect<rock::RockDialect>,
          markOpRecursivelyLegal<rock::AttentionOp>),
  - arith (addDynamicallyLegalDialect<arith::ArithDialect>),
  - tensor (addLegalDialect<tensor::TensorDialect> in the boundary
            conversion).

When the pass is scheduled standalone -- e.g. an off-tree tool that
runs migraphx-to-tosa before MIGraphXAttentionToRock has produced any
rock.attention -- the IR being converted may not yet contain ops in
those dialects, so they have to be loaded explicitly via
dependentDialects rather than relying on parse-time loading. Existing
in-tree pipelines already load these dialects via siblings, so this
is a no-op for the test suite (still 1372/1372 passing) but makes the
pass safe to use in isolation.

linalg / memref / math / bufferization are intentionally NOT added:
the pass only marks ops in rock.attention recursively legal, never
references those dialects' types or creates ops in them, and ops
already living inside an in-flight rock.attention imply their dialect
was loaded at parse time.

Made-with: Cursor
The previous round of C API hardening guarded the entry-point arguments
with assert(), which compiles to a no-op under NDEBUG. Release-build
callers that violated the contract (NULL queries / keys / values, NULL
preSoftmaxElemWiseInputs with a non-zero count, negative splitKV /
slidingWindowSize / numPreSoftmaxInputs, NULL resultType) would still
fall through to the same confusing failure modes the assertions were
meant to prevent: NULL deref while iterating the inputs array, "no Q
operand" diagnostics on a half-built op, or splitKV silently dropped.

Replace the asserts with `if (...) return reject(msg)` checks that fire
uniformly in debug and release builds, write a
"rocmlirMIGraphXAttentionCreate: <reason>" line to stderr (matching the
existing `mlirMIGraphXAddBackendPipeline` diagnostic style in this
file), and return a null MlirOperation. Callers can detect the failure
with mlirOperationIsNull. The order of checks now also catches a
negative numPreSoftmaxInputs before it is used as a loop bound, fixing
a subtle order-of-evaluation hazard in the previous assert chain.

Document the new contract in mlir-c/Dialect/MIGraphX.h: list every
condition that returns null and clarify that all *other* invariants
(operand element types, shape compatibility, feature/operand
consistency) are still enforced by the AttentionOp verifier rather
than by the C API. This matches what the implementation actually does
and lets callers distinguish "I gave it garbage" (null return,
detectable here) from "the verifier rejected my IR" (verified op
returned but parsing/verification fails later).

Existing CAPI tests (mlir/test/CAPI/mixr_attention.c) continue to pass
unchanged because they always provide valid inputs.

Made-with: Cursor
isAllowedInPreSoftmaxBody and lowerMIGraphXElementwiseToScalar already
handle migraphx.ceil -> math.ceil and migraphx.floor -> math.floor, but
no FileCheck verified that the dispatcher actually emits those math ops.
A regression that dropped either case from the dispatcher (or that
silently changed ceil/floor to a different math op) would still pass
the verifier-allowlist parity assert and walk past the existing
attention_extended_body_ops test, which exercised every other op in the
allowlist except these two.

Extend attention_extended_body_ops to chain ceil after abs and floor
after ceil, and add `// CHECK: math.ceil` / `// CHECK: math.floor`
between the math.absf and math.exp checks. The body still type-checks
end-to-end (all f32, same shape as the rest of the chain) so the test
remains a single compile-only regression net for the full closed set
of scalar lowerings the verifier is allowed to admit.

Update the comment above the test to (a) name ceil and floor
explicitly in the op list and (b) state outright that this test is
the regression net for the verifier/dispatcher contract, so a future
op addition has an obvious place to land its CHECK line instead of
silently accumulating coverage debt.

Made-with: Cursor
The previous round of input-contract enforcement (commit 2276637)
guarded queries / keys / values, the preSoftmaxElemWiseInputs array,
splitKV / slidingWindowSize / numPreSoftmaxInputs signs, and resultType
but missed preSoftmaxBody. The body is unwrapped and dereferenced
unconditionally further down (`mlir::Region *body = unwrap(...);
if (body->empty())`), so a caller passing a default-initialized
`(MlirRegion){NULL}` -- forgetting to call mlirRegionCreate() -- would
still crash on a null deref in both debug and release builds, defeating
the point of the rest of the contract checks.

Add a `mlirRegionIsNull(preSoftmaxBody)` check to the front-loaded
contract block alongside the other null guards, returning the same
null MlirOperation + stderr diagnostic as everything else. Steer
callers towards the right idiom in the diagnostic string itself
("use mlirRegionCreate() for an empty body") so the message is
self-explanatory without needing to read the header.

Update mlir/include/mlir-c/Dialect/MIGraphX.h to list this case in
the documented contract alongside the other null/negative rejections.

Existing CAPI tests in mlir/test/CAPI/mixr_attention.c always pass
the result of mlirRegionCreate() so they continue to pass unchanged;
the full RocMLIR test suite is unaffected.

Made-with: Cursor
mlir/test/CAPI/mixr_attention.c only exercised successful builders, so
the input-contract checks added in commits 7323bea, 2276637,
and 94e9f8a were not directly verified end-to-end. A future
refactor that compiled out a check, reordered the rejects, or changed
the diagnostic wording would slip past the existing tests and only
surface as a downstream NULL deref or confusing "no Q operand" error.

Add a single testAttentionRejectsInvalidInputs case that builds valid
Q/K/V scratch values once and reuses them across three sub-cases:
  1. null queries operand (Q/K/V null guards),
  2. null preSoftmaxBody (the bug fixed in 94e9f8a -- the previous
     code would crash on body->empty() in NDEBUG),
  3. negative splitKV (representative invalid scalar; pre-fix this was
     silently dropped by the writer's `splitKV > 1` guard).

Each sub-case asserts mlirOperationIsNull on the return value and
FileCheck pins the exact "rocmlirMIGraphXAttentionCreate: ..." stderr
diagnostic, including the "use mlirRegionCreate() for an empty body"
hint. Unowned MlirRegions from the rejected calls are explicitly
mlirRegionDestroy()'d so the test stays valgrind-clean.

The new case slots in as the 11th and final test in main() so its
output appears at the end of the FileCheck stream and doesn't perturb
the line ordering of any existing positive test.

Made-with: Cursor
Three independent rules and two NFC simplifications, all in the same
verify() pass:

1. Integer Q with an empty preSoftmaxBody is rejected. The first GEMM
   for integer Q is migraphx.quant_dot (i32 output), and the host
   decompose used to emit a bare migraphx.convert (raw bit-width cast)
   from i32 to the softmax type. Casting a quant_dot accumulator
   directly to f32 feeds enormous values to softmax and produces
   effectively-one-hot garbage. The legitimate path is a dequantize op
   in the body that applies the user's scale/bias, so require the body.
   The rule is independent of softmaxType; setting softmaxType only
   picks the float type, it does not synthesize the missing scale.

2. Zero-sized dims are rejected alongside dynamic dims. Every shape
   calculation downstream (broadcastForGQA, getNumHeads, splitKV %
   seqK, the GQA qd % kd divisibility check) is undefined or division
   by zero on a zero dim, and a zero-sized attention has no semantic
   meaning anyway. Folded into the existing rejectDynamic helper.

3. The GQA loop is gated on gqaActive instead of always iterating with
   an early continue, and both K/V and Q/K leading-dim loops use
   zip_equal so a future regression of the leading-dim count check
   asserts loudly in debug builds rather than silently truncating.

Adds three invalid.mlir tests (integer-Q empty body, integer-Q empty
body with softmax_type set, zero-sized K dim) and rewrites the
existing attention_i8_qk_missing_softmax_type test to match the new
diagnostic. Other invalid tests are untouched and still pass.

Co-authored-by: Cursor <cursoragent@cursor.com>
The LSE post-processing branch ended with:

  if (currentLseElemTy != lseOutputType.getElementType())
    lseValue = migraphx::ConvertOp::create(...);

But currentLseElemTy is softmaxElemType, which is computed as
op.getSoftmaxType().value_or(vType.getElementType()), and the verifier
requires lse's element type to equal exactly that effective softmax
type. The two are therefore always equal and the convert never fires.
Confirmed by inspecting the IR produced by --migraphx-transform on a
sample LSE op (no trailing convert is emitted).

Also drop the local reshapedLseType variable now that we can reshape
straight into lseOutputType.

Pure cleanup -- if the verifier rule is ever loosened, the convert can
come back; but until then the dead conditional is just noise that can
absorb a real type mismatch silently.

Co-authored-by: Cursor <cursoragent@cursor.com>
After the rest of the input-validation block was tightened, location
was the only required argument that the builder still dereferenced
without first checking. mlirLocationGetContext is called immediately
below the validation block, and mlirOperationStateGet and the YieldOp
builder for the empty-body path also dereference it; a default-
initialised MlirLocation would crash in release builds the same way
the inputs array used to before the rest of the contract was enforced.

Reject it up front with the same diagnostic shape as the other
checks, document it in the header alongside the rest of the contract,
and add a Case 4 to testAttentionRejectsInvalidInputs that passes
(MlirLocation){NULL} and asserts mlirOperationIsNull plus the exact
stderr line.

Co-authored-by: Cursor <cursoragent@cursor.com>
The mixr-attention-kvcache, mixr-attention-kvcache-causal-sliding-window,
and mixr-attention-kvcache-scale E2Es all run softmax in f16 (no
explicit softmax_type, so the lowering defaults to V's element type)
and use loose RMS / relDiff thresholds (0.02 / 0.1 and 0.03 / 0.2)
to absorb the resulting precision loss. The thresholds were
undocumented, so a future maintainer trying to tighten them had no
way to know whether they were hiding a real bug or honest f16
imprecision.

Annotate each with:
  * the source of the looseness (f16 softmax + reduce_sum + exp +
    recip eat the ~3 mantissa digits f16 has),
  * the empirical worst-case observed,
  * a pointer to the companion tight-threshold test in this directory
    that exercises the same masking math at standard 0.0005 thresholds
    (mixr-attention-kvcache-scale-lse for the scale and base kvcache
    cases, mixr-attention-kvcache-causal-sliding-window-clamp for the
    masking case),
  * the conditions under which tightening would be possible (set
    softmax_type = f32, change what's tested).

Pure documentation -- no behavior change. The companion tests already
guarantee that a regression in the masking lowering would not silently
hide behind these loose thresholds.

Co-authored-by: Cursor <cursoragent@cursor.com>
The three feature flags i8 Q/K, splitKV, and kvcache each individually
have an E2E test in this directory, but the combined path (the most
complex single configuration the op supports) was not exercised. The
cross-product is non-trivial because every layer has to coordinate:

  * host decompose: emit migraphx.quant_dot with i8 operands, reshape K
    and broadcast Q for splitKV in QK-shape integer space, inline a
    dequantize body operating in split-space [B, splitKV, seqQ,
    seqK/splitKV], apply the kvcache mask using global key indices
    spanning splitKV chunks, and compute per-chunk LSE in f32.
  * GPU lowering: mirror that flow through gridwise_attention_accel
    with the splitKV transform inversion in postProcessFirstGemm, and
    emit the dequantize-style body op as a sitofp + mulf in the
    linalg.generic body of rock.attention.

--verifier clone confirms the two paths agree to the standard tight
0.0005 thresholds. currentSeqLen is constrained to [80, 128) so every
splitKV chunk has at least 16 valid keys, mirroring the existing
mixr-attention-splitkv-kvcache test (a fully-masked chunk would
produce -inf per-chunk LSE and the per-chunk verifier comparison
would fail even though the merged final result would still be
correct).

Co-authored-by: Cursor <cursoragent@cursor.com>
git-clang-format wants the bracketed list packed onto fewer lines rather
than one entry per line. Apply the formatter's preferred layout so the
clang-format premerge gate stays clean. No semantic change.

Co-authored-by: Cursor <cursoragent@cursor.com>
…ks, splitKV diagnostic

Five focused defensive-coding improvements to migraphx.attention's lowering
chain that surfaced during a deeper review:

* MIGraphXAttentionToRock: assert that any numHeadsQ != numHeadsKV path goes
  through a 4D Q. The verifier already enforces this, but getNumHeads' rank<4
  fallback returns 1 for both Q and K, so a future verifier loosening (e.g.
  allowing MQA on rank-3) would silently produce a 1-head kernel for a true
  GQA workload. The assert trips loudly before that happens.

* MIGraphXAttentionToRock: assert that the rebuilt rock.attention
  preSoftmaxBody block has exactly one more arg than the source block (the
  trailing output memref). Documents and enforces the (QK, extras..., out)
  contract against future op-signature drift.

* GridwiseGemmToBlockwise: include the splitKV-transform chain length in the
  invertTransforms failure diagnostic so triage doesn't have to grep for
  which chain failed.

* CAPI rocmlirMIGraphXAttentionCreate: mirror the verifier's orphan-attr /
  orphan-operand checks at the API boundary so callers get
  "'splitKV' attribute requires feature 'splitkv'" (and the equivalents for
  slidingWindowSize, currentSeqLen, prefixOffset) before any IR is built,
  instead of an opaque op-verifier error after the fact. Update the header
  Doxygen and add four new negative cases (Cases 5-8) to the CAPI test.

Co-authored-by: Cursor <cursoragent@cursor.com>
…ines

* verifyAttentionLeadingDimsOperand: drop the unnecessary
  llvm::make_range(shape.begin(), shape.end()) wrap. Diagnostic::operator<<
  has a generic range overload that calls appendRange with ", " as the
  delimiter, so streaming the ArrayRef<int64_t> directly produces the same
  output as the qBatch print just to its left.

* attention-decompose tests: pin the f16/f32 element-type transitions on
  each migraphx.convert with CHECK-SAME so a future regression that drops
  a widen, drops a narrow, or turns a convert into a no-op (f16 -> f16)
  would break the test instead of silently passing on op presence alone.
  Covers decompose_with_softmax_type (Q*K widen, V widen, output narrow)
  and decompose_widened_second_gemm (QK widen, output narrow).

Co-authored-by: Cursor <cursoragent@cursor.com>
The pre-existing bounds-check builder for rock.transforming_for emitted
both the `sge 0` lower-bound check and the `slt bound` upper-bound check
for every padded dimension. For Pad{0, N} (no left padding) the `sge 0`
compare is unconditionally true because the affine map is identity, and
for Pad{N, 0} (no right padding) the `slt bound` compare is
unconditionally true. Relying on the canonicalizer to drop these isn't
free: the redundant compare sits in the validity-check inner loop body
of every padded transform until canonicalisation runs (or doesn't, if
the bounds aren't statically known to it), bloating the IR without need.

Gate the emission directly on the pad parameters: emit `sge` only when
the left padding is non-zero, emit `slt` only when the right padding is
non-zero. Embeds still emit both because they can wrap in either
direction.

Update bounds_check_pad to assert via CHECK-NOT that the redundant `sge`
no longer appears, and add a new bounds_check_pad_left_only case
covering the dual side: Pad{2, 0} should produce only `sge` and no
`slt`.

Co-authored-by: Cursor <cursoragent@cursor.com>
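To make the redundancy concrete (using the usual Pad{left, right} reading of rock's pad transform; this is an interpretation of the commit, not quoted code): for a dimension of input size n padded by l on the left and r on the right, a padded index i in [0, n + l + r) maps back to the input index j = i - l, and the validity predicate is

\[
0 \le j < n .
\]

With l = 0, j = i >= 0 always holds, so the sge check is vacuous; with r = 0, j = i - l < n always holds, so the slt check is vacuous. Embed transforms can step out of range in either direction, so they keep both checks.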
Two new files added in this branch had their LLVM-style file-header lines
broken across two lines because the descriptive title pushed the line
past the 80-column limit and clang-format wrapped it:

  //===- MIGraphXAttentionToRock.cpp - Lower migraphx.attention to rock
  //------===//

That's not the recommended LLVM banner shape and the trailing "----===//"
on the next line confuses doc tooling and human readers alike. Shorten
both titles to fit on one line, move the descriptive sentence into a
proper file-level comment block underneath, and use the standard
title-padded banner. Also cross-reference the polarity contract from
MIGraphXAttentionToRock.h's banner.

Co-authored-by: Cursor <cursoragent@cursor.com>
The host-side AttentionDecompose (in migraphx-transform) and the
kernel-side AttentionToRockPattern (in migraphx-attention-to-rock) use
opposite rock.kernel polarity guards so they can run back-to-back in
addHighLevelPipeline without stepping on each other. The contract was
implicit in the code; this commit makes it explicit and pins it with a
test.

* migraphx.attention's tablegen description now ends with a
  "Lowering polarity" paragraph spelling out that rock.kernel functions
  go through migraphx-attention-to-rock and everything else through
  migraphx-transform.

* MIGraphXAttentionToRockPass::runOnOperation now has a multi-line
  comment explaining why the rock.kernel guard is there, what flipping
  it would break, and pointing at the new end-to-end polarity test.

* MIGraphXToTosaPass's recursively-legal rationale is expanded to call
  out which dialects show up inside rock.attention's preSoftmaxBody
  region (linalg, memref, arith, math) and to explain why the pass's
  dependentDialects list doesn't have to enumerate them.

* New Lit test attention-pipeline-polarity.mlir runs both passes in
  sequence on a host function (no rock.kernel) and a kernel function
  (rock.kernel set), and asserts the host side decomposes to dot/softmax/dot
  with no rock.attention emitted, while the kernel side emits exactly one
  rock.attention with no leftover migraphx.attention.

* .gitignore: drop the leading slashes from plans/, scratch/, notes/ so
  contributors who keep their notes nested under e.g. mlir/plans/ also
  have them ignored, matching the spirit of the workspace-hygiene rule.

Co-authored-by: Cursor <cursoragent@cursor.com>
…rms producers

The CAPI Doxygen for rocmlirMIGraphXAttentionCreate enumerates seven
contract violations, but only six were exercised by the negative
test. Add cases 9-11 covering the remaining three: a negative
numPreSoftmaxInputs count, a NULL preSoftmaxElemWiseInputs array with
a positive count, and a null resultType. With these cases in place,
every documented reject path now has a FileCheck-pinned diagnostic in
mixr_attention.c, so future edits to the validation block (or the
header doc) cannot silently drop one.

Separately, expand the doc-comment for rock.attention's
preSoftmaxHasSplitKVTransforms attribute to reflect both producers
that can set it: DetectFlashDecoding (lifting splitKV out of
rock.transform ops on the body inputs) and MIGraphXAttentionToRock
(when lowering a migraphx.attention with both splitKV > 1 and a
non-empty preSoftmaxElemWiseInputs). Previously the doc only
mentioned DetectFlashDecoding, which made the second producer feel
implicit. Spell out the post-split layout the attribute implies and
the consequence in postProcessFirstGemm so future contributors don't
have to reverse-engineer the contract from GridwiseGemmToBlockwise.

Co-authored-by: Cursor <cursoragent@cursor.com>
Pulls together the cross-cutting contracts that several files (verifier,
host AttentionDecompose, GPU MIGraphXAttentionToRock, C API, rocmlir-gen)
all have to agree on, so a new contributor doesn't have to reverse-engineer
them from the code:

- Block-argument layout, allowed-ops list, and the verifier/lowering
  lock-step rule for preSoftmaxBody.
- The dequant-in-body rule for integer Q (why softmaxType alone is not
  enough).
- The rock.kernel polarity that splits handling between the host and
  GPU passes.
- How perf_config flows from migraphx.attention to rock.attention so
  tuningRunner.py hints reach the kernel generator.
- The checklist for adding a new feature flag.

Doc-only; no source changes.

Co-authored-by: Cursor <cursoragent@cursor.com>
Two correctness gaps surfaced in review where verifier-clean IR could
silently lower to wrong numerics:

1. Yield element type was unchecked. A body with a float-only op (whose
   result is unused) plus a yield of the i32 QK block-arg would pass
   both the empty-body integer-Q check (body has ops) and the body-op
   operand-type check (the unused op has float operands). The host
   decompose would then emit a bare migraphx.convert from i32 to
   softmaxType, feeding raw quantized accumulator values to softmax
   and producing one-hot garbage. Require the yield's element type to
   be float so producers must spell out their dequantize.

2. Q/K/V rank > 4 was unchecked outside the GQA path.
   MIGraphXAttentionToRock's getNumHeads reads dim 1 of a rank-4 shape
   and falls back to 1 for any other rank, which would silently
   produce a 1-head kernel for a real multi-head workload. Tighten
   the existing rank check to require rank 3 or 4 (the GQA-specific
   rank-must-be-4 branch already covers the GQA path; this closes the
   non-GQA hole).

Both gaps come with a negative test in invalid.mlir; the existing
GQA-rank-5 test is updated to match the new (earlier-fired) error
message and re-purposed as belt-and-braces for the general check.

Co-authored-by: Cursor <cursoragent@cursor.com>
Cleanup-only changes from the most recent review pass; no behavioural
change in any production codepath.

- AttentionToRockPattern: drop a dead `origLseType &&` conjunct from
  the LSE expand-shape gate. The whole block is already inside
  `if (op.getLse())`, which is the only path that initialises
  `origLseType`, so the extra null check could never fail.

- MIGraphXToTosaPass: keep the explicit `addLegalOp<rock::AttentionOp>`
  next to the recursive-legal mark and document why it must stay.
  Dropping it (relying on `addLegalDialect<rock::RockDialect>` alone)
  causes `markOpRecursivelyLegal` to lose its anchor and the body
  ops get walked into and crash the conversion -- caught by the
  existing migraphx-to-tosa-preserves-rock-attention test.

- Pipeline.cpp: revert a needless if/else expansion to the original
  `addPass(cond ? a : b)` ternary so the diff vs develop stays small.

- attention-pipeline-polarity.mlir: relax the kernel-side CHECK-SAME
  from `attributes {arch = "", rock.kernel}` to just `rock.kernel`,
  so the test stops being brittle to attribute-ordering changes (and
  to future additions next to `arch`).

- mixr_attention.c: align the license header with the other new
  attention files in this PR ("rocMLIR Project" instead of "LLVM
  Project") and add the copyright line that
  AttentionUtils.h / MIGraphXAttentionToRock.cpp already use.

Co-authored-by: Cursor <cursoragent@cursor.com>
Documentation, source-comment, and test-comment cleanup only. No
behavioural change in any production codepath; the verifier's
emitOpError diagnostic strings (and therefore every lit
expected-error checked against them) are byte-identical.

attention.md (mlir/docs/MIGraphX/):

- Quantized INT8 section reframed around an explicit
  Structural-vs-Semantic split. The verifier enforces "non-empty
  body that yields float" structurally; whether the body actually
  applies the user's quantization scale (via dequantizelinear vs
  bare convert) is producer responsibility. The previous claim
  that a bare convert i32 -> f32 is a "raw bit-width cast" was
  factually wrong (it lowers to sitofp/uitofp via
  convertScalarToDtype) and the previous "verifier closes every
  loophole" framing overstated what verifyOrphanAttr/operand can
  see. Updated in §4.9, §6 intro, §6.2, and §6.4 with the
  symmetry-rule qualifier (a convert-only body is structurally
  legal only when paired with at least one
  preSoftmaxElemWiseInput the convert silently ignores).

- Added a Contents section with all H2/H3 entries plus the H4
  subsections under §3.1 (splitkv).

- Fixed broken cross-reference anchor `#5-the-presoftmaxbody-contract`
  in three call sites; the actual heading is "## 5. The pre-softmax
  body contract (in detail)" so the anchor is now
  `#5-the-pre-softmax-body-contract-in-detail`.

- §3.1 splitKV: corrected the upper-bound implication
  "(1, 2, 4, ..., 128)" -> "(1, 2, 4, 8, ..., with no
  verifier-imposed upper bound)" since DetectFlashDecoding only
  checks llvm::isPowerOf2_64.

- §2.1 assembly format sketch: removed spurious commas between
  optional clauses and added the trailing `: <Q>, <K>, <V> -> <R>`
  type clause to match the real assemblyFormat in MIGraphX.td.

MIGraphX.cpp (mlir/lib/Dialect/MIGraphX/IR/):

- Two stale `// raw bit-width cast ... one-hot garbage` rationale
  blocks (the empty-body integer-Q branch around l.1015 and the
  yield-must-be-float branch around l.1080) replaced with the
  accurate sitofp/uitofp framing plus a pointer back to
  attention.md §6.2 for the structural-vs-semantic split.

invalid.mlir (mlir/test/Dialect/MIGraphX/):

- Same wording fix applied to the G3a / G3c test-case rationale
  comments. The `expected-error {{...}}` lines that lit checks
  against the verifier diagnostics are unchanged.

Co-authored-by: Cursor <cursoragent@cursor.com>