Draft

Commits

70 commits
c4a7ee3
add migraphx.attention op
umangyadav Mar 23, 2026
4ea302d
add migraphx.attention op using migraphx ops in presoftmaxbody
umangyadav Mar 24, 2026
1f3a188
add attention decompose pass
umangyadav Mar 24, 2026
20f12ba
add E2E lowering for migraphx.attention
umangyadav Mar 25, 2026
8e690ce
Formatting
umangyadav Mar 25, 2026
48438f1
add all features
umangyadav Mar 27, 2026
a1f23e3
add formatting
umangyadav Mar 27, 2026
58112c1
add return type for migraphx.yield
umangyadav Mar 27, 2026
7fb4556
formatting
umangyadav Mar 27, 2026
e138616
add some more tests
umangyadav Mar 27, 2026
f1c4b40
Review cycle
umangyadav Mar 27, 2026
d56a7cd
Use rock.kernel consistently in MIGraphX attention lowering.
umangyadav Apr 15, 2026
55084bb
add a fix to lib deps
umangyadav Apr 15, 2026
4d1cf9a
Some fixes
umangyadav Apr 30, 2026
72e7d93
[NFC] Ignore local plans/, scratch/, and notes/ directories
umangyadav Apr 30, 2026
f0fc503
Require explicit head broadcast for migraphx.attention seqLen/prefixO…
umangyadav Apr 30, 2026
cb4a701
Require slidingWindowSize attribute when sliding_window feature is set
umangyadav Apr 30, 2026
2b3adba
Support i8 Q/K via dequantize-in-body for migraphx.attention
umangyadav Apr 30, 2026
2c9f3f1
Require Q rank >= 4 for migraphx.attention GQA
umangyadav Apr 30, 2026
1ca017f
Validate preSoftmaxBody arg/yield types in migraphx.attention verifier
umangyadav May 1, 2026
dc0ddb3
Require softmaxType when softmax input doesn't match V's element type
umangyadav May 1, 2026
66ac8af
Add migraphx.attention E2E coverage mirroring pr-e2e/attention/padded*
umangyadav May 1, 2026
f2dc1a6
Clamp host sliding-window lower bound and tighten kvcache mask docs
umangyadav May 1, 2026
49efd39
Extend preSoftmaxBody scalar lowering and tighten verifier allowlist
umangyadav May 1, 2026
a648693
[NFC] Share migraphx.attention contracts via AttentionUtils.h
umangyadav May 1, 2026
aef2fac
[NFC] Simplify migraphx.attention verifier and host stride helpers
umangyadav May 1, 2026
82d9b82
Assert dispatcher / verifier allowlist parity and exercise sliding-wi…
umangyadav May 1, 2026
7b88f9f
Tighten migraphx.attention feature interactions
umangyadav May 1, 2026
514f9c9
[NFC] Clean up unused MIXR-to-tensor helper and host LSE reshape
umangyadav May 1, 2026
9ad89f5
Fix migraphx.attention splitKV lowering/decompose alignment
umangyadav May 1, 2026
1986544
Tighten migraphx.attention verifier across operand/attr edge cases
umangyadav May 1, 2026
b7f79c4
Document actual cause of splitkv-scale loose threshold (rock kernel bug)
umangyadav May 1, 2026
68fc74b
Refine splitkv-scale loose-threshold comment with grid-size detail
umangyadav May 1, 2026
c1127c2
Pinpoint splitkv-scale rock-side bug to postProcessFirstGemm read
umangyadav May 1, 2026
5eaeab9
Fix splitKV body-input read in postProcessFirstGemm
umangyadav May 2, 2026
049a068
Add splitKV+body+otherIns regression test for gridwise attention
umangyadav May 2, 2026
9b09acb
[NFC] Polish migraphx.attention helper construction
umangyadav May 2, 2026
be5408e
[NFC] Factor out shared shape and broadcast helpers in attention veri…
umangyadav May 2, 2026
d852889
Tighten migraphx.attention contract and widen host second GEMM
umangyadav May 2, 2026
276086b
Make C API attention helper produce verifier-clean IR
umangyadav May 2, 2026
42410d3
Fix migraphx.attention i8 mask decompose and reject result-type mismatch
umangyadav May 3, 2026
13fa5b2
Tighten migraphx.attention GQA to dim-1-on-rank-4 only
umangyadav May 3, 2026
af545f6
[NFC] Fix attention file copyright year and orphan TODO
umangyadav May 3, 2026
384aac7
Tighten migraphx.attention where-op body check to skip cond only
umangyadav May 3, 2026
71671d2
Convert splitKV invertTransforms assert to op.emitError
umangyadav May 3, 2026
2a641fb
Add splitKV+kvcache coverage for the rock.attention path
umangyadav May 3, 2026
3a70f7c
Anchor MIGraphXAttentionToRockPass to func::FuncOp
umangyadav May 3, 2026
7323bea
Harden rocmlirMIGraphXAttentionCreate input contract
umangyadav May 3, 2026
e63b7ec
Tighten loose thresholds in two attention E2E tests
umangyadav May 3, 2026
2bc70e5
[NFC] Declare rock/arith/tensor as MIGraphXToTosaPass dependentDialects
umangyadav May 3, 2026
2276637
Enforce rocmlirMIGraphXAttentionCreate input contract in release builds
umangyadav May 3, 2026
cfe5a69
Cover migraphx.ceil and migraphx.floor in attention scalar lowering test
umangyadav May 3, 2026
94e9f8a
Reject null preSoftmaxBody in rocmlirMIGraphXAttentionCreate
umangyadav May 3, 2026
e9eee32
Add CAPI negative test for rocmlirMIGraphXAttentionCreate contract
umangyadav May 3, 2026
b101d7a
Tighten migraphx.attention verifier across remaining edge cases
umangyadav May 3, 2026
b11c0cb
[NFC] Drop dead LSE convert in attention host decompose
umangyadav May 3, 2026
a1ecb96
Reject null location in rocmlirMIGraphXAttentionCreate
umangyadav May 3, 2026
3d321be
Document f16-softmax precision floor in three loose-threshold tests
umangyadav May 3, 2026
499b99b
Add i8 + splitKV + kvcache cross-product attention E2E test
umangyadav May 3, 2026
cdd6eb4
[NFC] clang-format MIGraphXToTosaPass dependentDialects list
umangyadav May 3, 2026
da07229
Tighten attention contracts: GQA assert, body arity, CAPI orphan chec…
umangyadav May 3, 2026
4c080a2
[NFC] Simplify diagnostic + tighten softmax_type widen/narrow CHECK l…
umangyadav May 3, 2026
66cc49c
Skip always-true bounds compares in pad/embed validity check
umangyadav May 3, 2026
d480606
[NFC] Repair clang-format-mangled headers in new attention files
umangyadav May 3, 2026
03fb918
Document attention pass polarity contract + add end-to-end polarity test
umangyadav May 3, 2026
dbbf99c
Pin remaining CAPI reject paths + clarify preSoftmaxHasSplitKVTransfo…
umangyadav May 4, 2026
3e9be5b
[NFC] Add developer doc for migraphx.attention preSoftmaxBody contract
umangyadav May 4, 2026
e149290
Tighten migraphx.attention verifier: yield float, reject rank > 4
umangyadav May 4, 2026
eabf4ed
[NFC] Address minor review nits in attention lowering and CAPI test
umangyadav May 4, 2026
103c5b9
[NFC] Clarify migraphx.attention quantization contract + add doc TOC
umangyadav May 4, 2026
7 changes: 7 additions & 0 deletions .gitignore
@@ -27,6 +27,13 @@
# Nested build directory
/build*

# Local plan/scratch notes (per workspace-hygiene rule). Match anywhere in
# the tree (no leading slash) so notes nested under e.g. mlir/plans/ also
# stay untracked.
plans/
scratch/
notes/

#==============================================================================#
# Explicit files to ignore (only matches one).
#==============================================================================#
2,028 changes: 2,028 additions & 0 deletions mlir/docs/MIGraphX/attention.md

Large diffs are not rendered by default.

61 changes: 61 additions & 0 deletions mlir/include/mlir-c/Dialect/MIGraphX.h
@@ -30,6 +30,10 @@ extern "C" {
// - mlirGetKernelAttrs() returns uint32_t[3] {block_size, grid_size,
// cluster_size} instead of uint32_t[2] {block_size, grid_size}.
// - Removed: mlirGetKernelInfo(), mlirMIGraphXAddApplicabilityPipeline().
// - Added: rocmlirMIGraphXAttentionCreate() for building migraphx.attention
// ops with variadic inputs, optional LSE, softmaxType, preSoftmaxBody,
// feature flags (kvcache, causal, prefix_offset, sliding_window, splitkv),
// currentSeqLen, prefixOffset, splitKV, and slidingWindowSize.
#define MLIR_MIGRAPHX_DIALECT_API_VERSION 5

typedef struct MlirMIGraphXBackendOptions {
@@ -38,6 +42,13 @@ typedef struct MlirMIGraphXBackendOptions {
int optLevel;
} MlirMIGraphXBackendOptions;

#define MLIR_MIGRAPHX_ATTENTION_NONE 0
#define MLIR_MIGRAPHX_ATTENTION_KVCACHE (1 << 0)
#define MLIR_MIGRAPHX_ATTENTION_CAUSAL (1 << 1)
#define MLIR_MIGRAPHX_ATTENTION_PREFIX_OFFSET (1 << 2)
#define MLIR_MIGRAPHX_ATTENTION_SLIDING_WINDOW (1 << 3)
#define MLIR_MIGRAPHX_ATTENTION_SPLITKV (1 << 4)

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(MIGraphX, migraphx);

// Types
@@ -73,6 +84,56 @@ MLIR_CAPI_EXPORTED void mlirMIGraphXAddHighLevelPipeline(MlirPassManager pm);
MLIR_CAPI_EXPORTED bool
mlirMIGraphXAddBackendPipeline(MlirPassManager pm,
const MlirMIGraphXBackendOptions *opts);

// Op creation helpers

/// Creates a `migraphx.attention` operation.
///
/// \p queries, \p keys, \p values are the required Q, K, V operands.
/// \p preSoftmaxElemWiseInputs is an array of \p numPreSoftmaxInputs additional
/// operands for element-wise fusion before softmax (can be NULL if 0).
/// \p resultType is the MIXRShaped type of the attention result (required).
/// \p lseType is the MIXRShaped type of the optional log-sum-exp output; pass
/// a null type (via mlirTypeIsNull) to omit.
/// \p softmaxType is the optional element type for softmax computation; pass
/// a null type to omit.
/// \p preSoftmaxBody is a caller-created region for pre-softmax element-wise
/// ops. Pass an empty region (mlirRegionCreate()) for a no-op body.
/// Ownership of the region transfers to the created operation.
/// \p features is the bitwise-OR of MLIR_MIGRAPHX_ATTENTION_* flags (0 = none).
/// \p currentSeqLen is required when kvcache is set; pass a null value to omit.
/// \p prefixOffset is required when prefix_offset is set; pass null to omit.
/// \p splitKV is the number of KV splits (0 or 1 = omit attribute).
/// \p slidingWindowSize is the window size (0 = omit attribute).
///
/// Contract violations are rejected with a stderr diagnostic and a null
/// MlirOperation return (check via mlirOperationIsNull). The same contract
/// is enforced in both debug and release builds. Specifically, the function
/// returns a null op (and writes a "rocmlirMIGraphXAttentionCreate: ..."
/// line to stderr) if \p location is null, if any of \p queries, \p keys,
/// \p values is null, if \p numPreSoftmaxInputs is negative or
/// \p preSoftmaxElemWiseInputs is NULL when the count is positive, if
/// \p splitKV or \p slidingWindowSize is negative, if \p resultType is
/// null, or if \p preSoftmaxBody is null (use mlirRegionCreate() for the
/// no-body case rather than a default-initialized struct).
///
/// The feature/attribute and feature/operand pairings from the op verifier
/// are also enforced here so the diagnostic happens before any IR is
/// constructed: \p splitKV > 1 requires MLIR_MIGRAPHX_ATTENTION_SPLITKV in
/// \p features, \p slidingWindowSize > 0 requires
/// MLIR_MIGRAPHX_ATTENTION_SLIDING_WINDOW, a non-null \p currentSeqLen
/// requires MLIR_MIGRAPHX_ATTENTION_KVCACHE, and a non-null
/// \p prefixOffset requires MLIR_MIGRAPHX_ATTENTION_PREFIX_OFFSET. All
/// other invariants (operand element types, shape compatibility, the
/// missing-operand-required-by-feature direction, etc.) are still left to
/// the AttentionOp verifier.
MLIR_CAPI_EXPORTED MlirOperation rocmlirMIGraphXAttentionCreate(
MlirLocation location, MlirValue queries, MlirValue keys, MlirValue values,
intptr_t numPreSoftmaxInputs, const MlirValue *preSoftmaxElemWiseInputs,
MlirType resultType, MlirType lseType, MlirType softmaxType,
MlirRegion preSoftmaxBody, uint32_t features, MlirValue currentSeqLen,
MlirValue prefixOffset, int32_t splitKV, int32_t slidingWindowSize);

#ifdef __cplusplus
}
#endif
27 changes: 27 additions & 0 deletions mlir/include/mlir/Conversion/MIGraphXAttentionToRock/MIGraphXAttentionToRock.h
@@ -0,0 +1,27 @@
//===-- MIGraphXAttentionToRock.h -------------------------------*- C++ -*-===//
//
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Copyright (c) 2026 Advanced Micro Devices
//
// Pass declaration for lowering migraphx.attention to rock.attention.
// See MIGraphXAttentionToRock.cpp for the polarity contract with the
// host-side AttentionDecompose pattern.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_MIGRAPHXATTENTIONTOROCK_H
#define MLIR_CONVERSION_MIGRAPHXATTENTIONTOROCK_H

#include "mlir/Pass/Pass.h"

namespace mlir {

#define GEN_PASS_DECL_MIGRAPHXATTENTIONTOROCKPASS
#include "mlir/Conversion/RocMLIRPasses.h.inc"

} // namespace mlir

#endif // MLIR_CONVERSION_MIGRAPHXATTENTIONTOROCK_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/RocMLIRPasses.h
@@ -12,6 +12,7 @@
#include "mlir/Conversion/EmulateFp8ExtTrunc/EmulateFp8ExtTrunc.h"
#include "mlir/Conversion/FixTosaCastRounding/FixTosaCastRounding.h"
#include "mlir/Conversion/LinalgToRock/LinalgToRock.h"
#include "mlir/Conversion/MIGraphXAttentionToRock/MIGraphXAttentionToRock.h"
#include "mlir/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.h"
#include "mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h"
#include "mlir/Conversion/Passes.h"
43 changes: 39 additions & 4 deletions mlir/include/mlir/Conversion/RocMLIRPasses.td
@@ -124,10 +124,23 @@ def MIGraphXToTosaPass : Pass<"migraphx-to-tosa", "::mlir::func::FuncOp"> {
Pass that converts MIGraphX operations to TOSA operations.
}];

let dependentDialects = [
"func::FuncDialect",
"tosa::TosaDialect",
"mhal::MHALDialect",
// The `arith`, `rock`, and `tensor` dialects appear in the conversion
// target's legality rules (addLegalDialect / addDynamicallyLegalDialect /
// markOpRecursivelyLegal<rock::AttentionOp>) so that rock.attention and
// any pre-existing arith/tensor ops survive the partial conversion. They
// need to be loaded by the time the pass runs even when the IR being
// converted doesn't yet contain ops in those dialects (e.g. an off-tree
// tool that schedules just this pass before MIGraphXAttentionToRock has
// produced the rock.attention).
//
// The dialects nested inside rock.attention's region (linalg, math,
// memref, bufferization, ...) are intentionally NOT listed here: this
// pass only marks them recursively legal, never references their types
// or creates ops in them, and ops inside an in-flight rock.attention
// imply their dialect was already loaded at parse time.
let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
"mhal::MHALDialect", "rock::RockDialect",
"tensor::TensorDialect", "tosa::TosaDialect",
];
}

@@ -193,4 +206,26 @@ def LinalgToRockPass : Pass<"linalg-to-rock", "::mlir::func::FuncOp"> {
let dependentDialects = ["rock::RockDialect",
"bufferization::BufferizationDialect"];
}
//===----------------------------------------------------------------------===//
// MIGraphXAttentionToRock
//===----------------------------------------------------------------------===//
def MIGraphXAttentionToRockPass
: Pass<"migraphx-attention-to-rock", "::mlir::func::FuncOp"> {
let summary = "Lower migraphx.attention to rock.attention";
let description = [{
Pass that converts migraphx.attention operations directly to
rock.attention operations for GPU compilation.

Anchored to func::FuncOp for consistency with the other migraphx-side
conversion pass (MIGraphXToTosaPass) and because the pipeline already
schedules it inside a func-nested pass manager.
}];
let dependentDialects = ["arith::ArithDialect", "linalg::LinalgDialect",
"math::MathDialect", "rock::RockDialect",
"bufferization::BufferizationDialect",
"memref::MemRefDialect", "tensor::TensorDialect",
"migraphx::MIGraphXDialect",
];
}

#endif // ROCMLIR_CONVERSION_PASSES
75 changes: 75 additions & 0 deletions mlir/include/mlir/Dialect/MIGraphX/IR/AttentionUtils.h
@@ -0,0 +1,75 @@
//===- AttentionUtils.h - Shared rules for migraphx.attention ---*- C++ -*-===//
//
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Copyright (c) 2026 Advanced Micro Devices
//
//===----------------------------------------------------------------------===//
//
// Small inline helpers that encode contracts shared by several pieces of
// migraphx.attention's lowering chain. Keeping them here means the
// verifier, the host AttentionDecompose, the GPU MIGraphXAttentionToRock
// lowering, and rocmlir-gen all derive the same answers from the same code.
//
// Anything that's only used in one place, or that requires
// path-specific inputs (e.g. expectedQKShape, which the verifier
// computes from pre-splitKV operands while the host decompose computes
// from post-splitKV-reshaped types), should stay local to that pass.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_MIGRAPHX_IR_ATTENTIONUTILS_H_
#define MLIR_MIGRAPHX_IR_ATTENTIONUTILS_H_

#include "mlir/Dialect/MIGraphX/IR/MIGraphX.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace migraphx {

/// The element type that the first GEMM (Q*K) of a migraphx.attention
/// produces, given Q's element type. For float Q the QK output stays in
/// Q's type; for integer Q the first GEMM is a quantized matmul whose
/// output is i32 (the body is then expected to dequantize that i32 to a
/// float type). Used by AttentionOp::verify, MIGraphXTransform's host
/// AttentionDecompose, and rocmlir-gen so all three derive the same QK
/// type for the same Q.
inline Type computeAttentionQKElemType(Type qElemType, MLIRContext *ctx) {
if (isa<FloatType>(qElemType))
return qElemType;
return IntegerType::get(ctx, 32);
}

/// Returns true if `op` is in the closed set of migraphx ops that
/// MIGraphXAttentionToRock::lowerMIGraphXElementwiseToScalar can lower to
/// a scalar arith / math equivalent inside a linalg.generic body. The
/// AttentionOp verifier consults this so the verifier never accepts a
/// preSoftmaxBody that the lowering would later reject; the lowering
/// itself uses the same membership rule (encoded as a dispatch table) to
/// decide what to emit.
///
/// IMPORTANT: this list and
/// MIGraphXAttentionToRock::lowerMIGraphXElementwiseToScalar must stay in
/// lock-step. Adding a new body op is a one-line change in two coupled
/// places (this function plus the lowering's dispatch table). The
/// AttentionToRockPattern body-builder asserts at runtime that any op in
/// this allowlist is also handled by the dispatcher, so divergence trips
/// the assertion (debug builds) or surfaces as a structured
/// "unsupported migraphx op in preSoftmaxBody" error (release builds).
inline bool isAllowedInPreSoftmaxBody(Operation &op) {
return isa<migraphx::AddOp, migraphx::SubOp, migraphx::MulOp, migraphx::DivOp,
migraphx::PowOp, migraphx::NegOp, migraphx::AbsOp,
migraphx::CeilOp, migraphx::FloorOp, migraphx::ExpOp,
migraphx::LogOp, migraphx::SqrtOp, migraphx::TanhOp,
migraphx::ErfOp, migraphx::RecipOp, migraphx::ReluOp,
migraphx::SigmoidOp, migraphx::WhereOp, migraphx::ConvertOp,
migraphx::DeQuantizeLinearOp>(op);
}

} // namespace migraphx
} // namespace mlir

#endif // MLIR_MIGRAPHX_IR_ATTENTIONUTILS_H_
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.h
@@ -35,6 +35,17 @@ namespace migraphx {} // end namespace migraphx

#include "mlir/Dialect/MIGraphX/IR/MIGraphXEnums.h.inc"

namespace mlir {
namespace migraphx {
inline bool hasAttentionFeature(std::optional<AttentionFeatures> features,
AttentionFeatures flag) {
if (!features)
return false;
return bitEnumContainsAll(*features, flag);
}
} // namespace migraphx
} // namespace mlir

#define GET_OP_CLASSES
#include "mlir/Dialect/MIGraphX/IR/MIGraphX.h.inc"
