Skip to content

Commit 20f12ba

Browse files
committed
add E2E lowering for migraphx.attention
1 parent 1f3a188 commit 20f12ba

25 files changed

Lines changed: 1105 additions & 190 deletions

File tree

mlir/include/mlir-c/Dialect/MIGraphX.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ mlirMIGraphXAddBackendPipeline(MlirPassManager pm,
8383
/// \p queries, \p keys, \p values are the required Q, K, V operands.
8484
/// \p preSoftmaxElemWiseInputs is an array of \p numPreSoftmaxInputs additional
8585
/// operands for element-wise fusion before softmax (can be NULL if 0).
86-
/// \p resultType is the tensor type of the attention result (required).
87-
/// \p lseType is the tensor type of the optional log-sum-exp output; pass
86+
/// \p resultType is the MIXRShaped type of the attention result (required).
87+
/// \p lseType is the MIXRShaped type of the optional log-sum-exp output; pass
8888
/// a null type (via mlirTypeIsNull) to omit.
8989
/// \p softmaxType is the optional element type for softmax computation; pass
9090
/// a null type to omit.
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===-- MIGraphXAttentionToRock.h - Lower migraphx.attention to rock --*- C++
2+
//-*-===//
3+
//
4+
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
5+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
// Copyright (c) 2025 Advanced Micro Devices
9+
//
10+
//===----------------------------------------------------------------------===//
11+
12+
#ifndef MLIR_CONVERSION_MIGRAPHXATTENTIONTOROCK_H
13+
#define MLIR_CONVERSION_MIGRAPHXATTENTIONTOROCK_H
14+
15+
#include "mlir/Pass/Pass.h"
16+
17+
namespace mlir {
18+
19+
#define GEN_PASS_DECL_MIGRAPHXATTENTIONTOROCKPASS
20+
#include "mlir/Conversion/RocMLIRPasses.h.inc"
21+
22+
} // namespace mlir
23+
24+
#endif // MLIR_CONVERSION_MIGRAPHXATTENTIONTOROCK_H

mlir/include/mlir/Conversion/RocMLIRPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Conversion/EmulateFp8ExtTrunc/EmulateFp8ExtTrunc.h"
1313
#include "mlir/Conversion/FixTosaCastRounding/FixTosaCastRounding.h"
1414
#include "mlir/Conversion/LinalgToRock/LinalgToRock.h"
15+
#include "mlir/Conversion/MIGraphXAttentionToRock/MIGraphXAttentionToRock.h"
1516
#include "mlir/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.h"
1617
#include "mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h"
1718
#include "mlir/Conversion/Passes.h"

mlir/include/mlir/Conversion/RocMLIRPasses.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,4 +193,25 @@ def LinalgToRockPass : Pass<"linalg-to-rock", "::mlir::func::FuncOp"> {
193193
let dependentDialects = ["rock::RockDialect",
194194
"bufferization::BufferizationDialect"];
195195
}
196+
//===----------------------------------------------------------------------===//
197+
// MIGraphXAttentionToRock
198+
//===----------------------------------------------------------------------===//
199+
def MIGraphXAttentionToRockPass
200+
: Pass<"migraphx-attention-to-rock"> {
201+
let summary = "Lower migraphx.attention to rock.attention";
202+
let description = [{
203+
Pass that converts migraphx.attention operations directly to
204+
rock.attention operations for GPU compilation.
205+
}];
206+
let dependentDialects = [
207+
"arith::ArithDialect",
208+
"linalg::LinalgDialect",
209+
"rock::RockDialect",
210+
"bufferization::BufferizationDialect",
211+
"memref::MemRefDialect",
212+
"tensor::TensorDialect",
213+
"migraphx::MIGraphXDialect",
214+
];
215+
}
216+
196217
#endif // ROCMLIR_CONVERSION_PASSES

mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,10 +598,10 @@ def MIGraphX_AttentionOp
598598
Arguments<(ins MIXRShapedOf<AttentionQKTypes>:$queries,
599599
MIXRShapedOf<AttentionQKTypes>:$keys,
600600
MIXRShapedOf<AttentionVTypes>:$values,
601-
Variadic<AnyTypeOf<[AnyMIXRShaped, AnyRankedTensor]>>:$preSoftmaxElemWiseInputs,
601+
Variadic<AnyMIXRShaped>:$preSoftmaxElemWiseInputs,
602602
OptionalAttr<TypeAttr>:$softmaxType)>,
603-
Results<(outs AnyRankedTensor:$result,
604-
Optional<AnyRankedTensor>:$lse)> {
603+
Results<(outs AnyMIXRShaped:$result,
604+
Optional<AnyMIXRShaped>:$lse)> {
605605
let summary = "Attention operation of transformer models";
606606
let description = [{
607607
Performs the operation result = SOFTMAX(preSoftmaxBody(queries * keys^T, preSoftmaxElemWiseInputs)) * values.

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_subdirectory(EmulateFp8ExtTrunc)
22
add_subdirectory(FixTosaCastRounding)
3+
add_subdirectory(MIGraphXAttentionToRock)
34
add_subdirectory(MIGraphXToTosa)
45
add_subdirectory(RocmlirCustomTosaDecompose)
56
add_subdirectory(RocmlirCustomTosaToLinalg)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
add_rocmlir_conversion_library(MLIRMIGraphXAttentionToRock
2+
MIGraphXAttentionToRock.cpp
3+
4+
DEPENDS
5+
RocMLIRConversionPassIncGen
6+
7+
LINK_LIBS PUBLIC
8+
MLIRArithDialect
9+
MLIRIR
10+
MLIRFuncDialect
11+
MLIRLinalgDialect
12+
MLIRMIGraphXDialect
13+
MLIRMemRefDialect
14+
MLIRRockOps
15+
MLIRPass
16+
MLIRTensorDialect
17+
MLIRBufferizationDialect
18+
MLIRTransformUtils
19+
MLIRSupport
20+
)

0 commit comments

Comments
 (0)