diff --git a/external/mlir-hal/include/mlir/Conversion/MHALPasses.h b/external/mlir-hal/include/mlir/Conversion/MHALPasses.h
index 14980dee3445..5f960c45921f 100644
--- a/external/mlir-hal/include/mlir/Conversion/MHALPasses.h
+++ b/external/mlir-hal/include/mlir/Conversion/MHALPasses.h
@@ -9,7 +9,6 @@
 #ifndef MHAL_CONVERSION_PASSES_H
 #define MHAL_CONVERSION_PASSES_H
 
-#include "mlir/Conversion/MHALToCPU/MHALToCPU.h"
 #include "mlir/Conversion/MHALToGPU/MHALToGPU.h"
 #include "mlir/Pass/PassRegistry.h"
 
diff --git a/external/mlir-hal/include/mlir/Conversion/MHALPasses.td b/external/mlir-hal/include/mlir/Conversion/MHALPasses.td
index ba84fa45faaa..10e8875c66d3 100644
--- a/external/mlir-hal/include/mlir/Conversion/MHALPasses.td
+++ b/external/mlir-hal/include/mlir/Conversion/MHALPasses.td
@@ -16,17 +16,9 @@ include "mlir/Pass/PassBase.td"
 //===----------------------------------------------------------------------===//
 
 def ConvertMHALToGPUPass : Pass<"convert-mhal-to-gpu", "ModuleOp"> {
-  let summary = "Convert the mhal.launch operations to gpu.launch_func";
+  let summary = "Lower bufferized func.call to GPU kernels (mhal.targets) to "
+                "gpu.launch_func";
   let dependentDialects = ["gpu::GPUDialect"];
 }
 
-//===----------------------------------------------------------------------===//
-// MHALToCPU
-//===----------------------------------------------------------------------===//
-
-def ConvertMHALToCPUPass : Pass<"convert-mhal-to-cpu", "ModuleOp"> {
-  let summary = "Convert the mhal.launch operations to func.call";
-  let dependentDialects = ["func::FuncDialect"];
-}
-
 #endif // MHAL_CONVERSION_PASSES
diff --git a/external/mlir-hal/include/mlir/Conversion/MHALToCPU/MHALToCPU.h b/external/mlir-hal/include/mlir/Conversion/MHALToCPU/MHALToCPU.h
deleted file mode 100644
index 87e0955ddaab..000000000000
--- a/external/mlir-hal/include/mlir/Conversion/MHALToCPU/MHALToCPU.h
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- MHALToCPU.h - Convert MHAL to CPU dialect ------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_MHALTOCPU_MHALTOCPU_H
-#define MLIR_CONVERSION_MHALTOCPU_MHALTOCPU_H
-
-#include <memory>
-
-namespace mlir {
-class Pass;
-
-#define GEN_PASS_DECL_CONVERTMHALTOCPUPASS
-#include "mlir/Conversion/MHALPasses.h.inc"
-
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_MHALTOCPU_MHALTOCPU_H
diff --git a/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHAL.h b/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHAL.h
index c0a68855e249..5f1da8a5f04e 100644
--- a/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHAL.h
+++ b/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHAL.h
@@ -13,11 +13,8 @@
 #ifndef MLIR_MHAL_IR_MHAL_H_
 #define MLIR_MHAL_IR_MHAL_H_
 
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/Interfaces/CallInterfaces.h"
 
 //===----------------------------------------------------------------------===//
 // MHAL Dialect
@@ -25,8 +22,6 @@
 #include "mlir/Dialect/MHAL/IR/MHALOpsDialect.h.inc"
 #include "mlir/Dialect/MHAL/IR/MHALTypes.h"
 
-#define GET_OP_CLASSES
-#include "mlir/Dialect/MHAL/IR/MHALOps.h.inc"
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/MHAL/IR/MHALAttrDefs.h.inc"
diff --git a/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALBase.td b/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALBase.td
index 4911c7092bbe..3ca750da0892 100644
--- a/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALBase.td
+++ b/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALBase.td
@@ -15,9 +15,8 @@
 def MHALDialect : Dialect {
   let name = "mhal";
   let cppNamespace = "::mlir::mhal";
-  let summary = "Types and operations for mhal dialect";
+  let summary = "Attributes for the MHAL dialect";
 
-  let useDefaultTypePrinterParser = 1;
   let useDefaultAttributePrinterParser = 1;
 
   let extraClassDeclaration = [{
diff --git a/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALOps.td b/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALOps.td
index f3e65a9976ff..15902c6ec390 100644
--- a/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALOps.td
+++ b/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALOps.td
@@ -1,4 +1,4 @@
-//===- MHALOps.td - MHAL operations definition -----------*- tablegen -*-===//
+//===- MHALOps.td - MHAL dialect TableGen entry point ----*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This is the operation definition file for MHAL dialect operations.
+// TableGen entry point referenced by `add_mlir_dialect(MHALOps mhal)`. The
+// MHAL dialect currently exposes only attributes (e.g. mhal.targets,
+// mhal.write_access) -- no ops, no types. The dialect class itself is defined
+// in MHALBase.td; attribute definitions live in MHALAttrDefs.td.
 //
 //===----------------------------------------------------------------------===//
 
@@ -15,136 +18,5 @@
 include "mlir/Dialect/MHAL/IR/MHALBase.td"
 include "mlir/Dialect/MHAL/IR/MHALAttrDefs.td"
-include "mlir/Interfaces/CallInterfaces.td"
-include "mlir/Interfaces/ControlFlowInterfaces.td"
-include "mlir/IR/SymbolInterfaces.td"
-include "mlir/IR/OpAsmInterface.td"
-
-
-//===----------------------------------------------------------------------===//
-// MHAL op definitions
-//===----------------------------------------------------------------------===//
-
-// Base class for the operation in this dialect
-class MHAL_Op<string mnemonic, list<Trait> traits = []> :
-    Op<MHALDialect, mnemonic, traits>;
-
-def MHAL_LaunchOp :
-    MHAL_Op<"launch", [DeclareOpInterfaceMethods<CallOpInterface>,
-                       DeclareOpInterfaceMethods<SymbolUserOpInterface>,
-                       AttrSizedOperandSegments]> {
-  let summary = "asynchronous launch operation";
-  let description = [{
-    The `mhal.launch` operation encapsulates a call to a `kernel` func that
-    may run asynchronously from the caller. An `mhal.token` is returned to
-    maintain ordered dependent execution. Subsequent instructions that depend
-    on results from the `mhal.launch` must either be dependent on the token
-    or be preceded by an `mhal.await` of that token.
-
-    The actual concurrency semantics depends on the dialect lowering to the
-    executable format. Fully sequential execution ("compute0" completes before
-    "compute1" starts) is a completely legal execution.
-
-    Because concurrent execution is not guaranteed, it is undefined behavior
-    to create an implicit dependency from "compute1" to "compute0" (e.g. via
-    shared global state). All dependencies must be made explicit with mhal
-    launch arguments (i.e. `mhal.token`).
-
-    `mhal.launch` operation takes `mhal.token` dependencies and operands
-    separately, and starts execution of the kernel function only when all
-    tokens become ready. In contrast to `mhal.execute` this operation does
-    not depend on operand values as `mhal.value` requires.
-
-    Example:
-
-    ```mlir
-    %dependency = ... : !mhal.token
-
-    %token, %results =
-      mhal.launch @compute0 [%dependency] (%value0, %value1) :
-        (some.type, some.type) -> some.type
-
-    %1 = "compute1"(...) : !some.type
-    ```
-
-    In the example above asynchronous execution starts only after the
-    dependency token becomes ready.
-  }];
-
-  let arguments = (ins FlatSymbolRefAttr:$callee,
-                       Variadic<MHAL_TokenType>:$dependencies,
-                       Variadic<AnyType>:$launchOperands);
-
-  let results = (outs MHAL_TokenType:$token,
-                      Variadic<AnyType>:$results);
-
-  let builders = [
-    OpBuilder<(ins "func::FuncOp":$kernelFunc, "ValueRange":$dependencies,
-                   "ValueRange":$kernelOperands)>,
-    OpBuilder<(ins "FlatSymbolRefAttr":$callee, "TypeRange":$results,
-               CArg<"ValueRange", "{}">:$launchOperands), [{
-      $_state.addOperands(launchOperands);
-      $_state.addAttribute("callee", callee);
-      $_state.addTypes(results);
-      auto operandSegmentSizes = $_builder.getDenseI32ArrayAttr(
-          {0, static_cast<int32_t>(launchOperands.size())});
-      $_state.addAttribute(getOperandSegmentSizesAttrName($_state.name),
-                           operandSegmentSizes);
-
-    }]>,
-  ];
-
-  let extraClassDeclaration = [{
-    Operation::result_range getCallResults();
-    Operation::result_type_range getCallResultTypes();
-    void updateSegmentSizes(MLIRContext *);
-  }];
-
-  let hasVerifier = 1;
-  let assemblyFormat = [{
-    $callee (` ` `[` $dependencies^ `]`)? `(` $launchOperands `)` attr-dict
-    `:` `(` type($launchOperands) `)` (`->` type($results)^)?
- }]; - -} - -def MHAL_AwaitOp : MHAL_Op<"await"> { - let summary = "waits for the argument to become ready"; - let description = [{ - The `mhal.await` operation waits until the argument becomes ready, and for - - Example: - - ```mlir - %0 = ... : !mhal.token - mhal.await %0 : !mhal.token - - ``` - }]; - - let arguments = (ins MHAL_TokenType:$operand); - let results = (outs Optional:$result); - - let skipDefaultBuilders = 1; - let hasVerifier = 1; - - let builders = [ - OpBuilder<(ins "Value":$operand, - CArg<"ArrayRef", "{}">:$attrs)>, - ]; - - let extraClassDeclaration = [{ - std::optional getResultType() { - if (getResultTypes().empty()) return std::nullopt; - return getResultTypes()[0]; - } - }]; - - let assemblyFormat = [{ - $operand `:` custom( - type($operand), type($result) - ) attr-dict - }]; -} #endif // MHAL_OPS diff --git a/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALTypes.td b/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALTypes.td index 495bb6a91eba..72f7183c699d 100644 --- a/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALTypes.td +++ b/external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALTypes.td @@ -18,22 +18,9 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/Dialect/MHAL/IR/MHALBase.td" //===----------------------------------------------------------------------===// -// MHAL Types +// MHAL Enums //===----------------------------------------------------------------------===// -class MHAL_Type : TypeDef { - let mnemonic = typeMnemonic; -} - -def MHAL_TokenType : MHAL_Type<"Token", "token"> { - let summary = "mhal token type"; - let description = [{ - `mhal.token` is a type returned by mhalhronous operations, and it becomes - `available` when the mhalhronous operations that created it is completed. - }]; -} - class MHAL_I32Enum cases> : I32EnumAttr { let cppNamespace = "::mlir::mhal"; diff --git a/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/BufferizableOpInterfaceImpl.h b/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/BufferizableOpInterfaceImpl.h deleted file mode 100644 index d636ee5c7331..000000000000 --- a/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/BufferizableOpInterfaceImpl.h +++ /dev/null @@ -1,22 +0,0 @@ -//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_MHAL_BUFFERIZABLEOPINTERFACEIMPL_H
-#define MLIR_DIALECT_MHAL_BUFFERIZABLEOPINTERFACEIMPL_H
-
-#include "mlir/IR/DialectRegistry.h"
-
-namespace mlir {
-namespace mhal {
-
-void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
-
-} // namespace mhal
-} // namespace mlir
-
-#endif // MLIR_DIALECT_MHAL_BUFFERIZABLEOPINTERFACEIMPL_H
diff --git a/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/Passes.h b/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/Passes.h
index 69fac93c44b0..60df71b0c5d7 100644
--- a/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/Passes.h
+++ b/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/Passes.h
@@ -45,12 +45,6 @@ namespace mhal {
 void populateMHalNarrowTypeEmulationConversions(
     arith::NarrowTypeEmulationConverter &typeConverter);
 
-/// Adds patterns for rewriting `mhal.launch` ops to `patterns` that replace
-/// 4-bit (or other narrow power of two) memrefs to 8-bit ones.
-void populateMHalNarrowTypeEmulationBoundaryPatterns(
-    arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
-
 /// Adds patterns that handle `extract_strided_metadata` ops targeting the
 /// `builtin.unrealized_conversion_cast` operations that the type conversion
 /// process introduces to prevent dialect conversion from failing due to stray
diff --git a/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/Passes.td b/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/Passes.td
index 058a7d7b9dd1..11cbe888831c 100644
--- a/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/Passes.td
+++ b/external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/Passes.td
@@ -22,11 +22,11 @@ def MHALBufferizePass : Pass<"mhal-bufferize", "func::FuncOp"> {
 }
 
 def MHalEmulateNarrowTypePass : Pass<"mhal-emulate-narrow-type", "func::FuncOp"> {
-  let summary = "Emulate memrefs of 4-bit integers (or 2-bit, maybe)";
+  let summary = "Emulate memrefs of 4-bit integers (or 2-bit, maybe)";
   let description = [{
     Coordinates the rewrite patterns present in upstream MLIR for supporting
-    i4 memrefs by rewriting them to linear i8 ones. Also rewrite `mhal.launch`
-    just like `func.call` gets rewritten.
+    i4 memrefs by rewriting them to linear i8 ones, and rewrites `func.call`
+    boundaries accordingly.
 
     Note that this pass is meant to target the host/test code, and so doesn't
     handle vector loads/stores because they won't have been produced by the
     host runner.
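(Reviewer sketch, not part of the patch: the retargeted `mhal-emulate-narrow-type` boundary rewrite described above, shown on a hypothetical bufferized call. The function name and shapes are invented; only the i4-to-linear-i8 packing is what the pass description promises.)

```mlir
// Before: a 4-bit memref crosses the host call boundary.
func.func private @pack_kernel(%arg0: memref<8xi4>)

func.func @host(%m: memref<8xi4>) {
  func.call @pack_kernel(%m) : (memref<8xi4>) -> ()
  return
}

// After: the boundary is rewritten to a linear i8 memref, packing two
// i4 values per i8 element (8 x i4 becomes 4 x i8).
func.func private @pack_kernel(%arg0: memref<4xi8>)

func.func @host(%m: memref<4xi8>) {
  func.call @pack_kernel(%m) : (memref<4xi8>) -> ()
  return
}
```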
diff --git a/external/mlir-hal/include/mlir/InitMHALDialects.h b/external/mlir-hal/include/mlir/InitMHALDialects.h
index ecaac586d5fb..554aa68c760f 100644
--- a/external/mlir-hal/include/mlir/InitMHALDialects.h
+++ b/external/mlir-hal/include/mlir/InitMHALDialects.h
@@ -16,7 +16,6 @@
 
 // MHAL includes
 #include "mlir/Dialect/MHAL/IR/MHAL.h"
-#include "mlir/Dialect/MHAL/Transforms/BufferizableOpInterfaceImpl.h"
 
 namespace mlir {
 
@@ -24,9 +23,6 @@ namespace mlir {
 inline void registerMHALDialects(DialectRegistry &registry) {
   // Register MHAL specific dialects
   registry.insert<mhal::MHALDialect>();
-
-  // Register bufferization hooks for mhal interfaces
-  mhal::registerBufferizableOpInterfaceExternalModels(registry);
 }
 
 } // namespace mlir
diff --git a/external/mlir-hal/lib/Conversion/CMakeLists.txt b/external/mlir-hal/lib/Conversion/CMakeLists.txt
index 58bba89b7793..8ddf4712ee79 100644
--- a/external/mlir-hal/lib/Conversion/CMakeLists.txt
+++ b/external/mlir-hal/lib/Conversion/CMakeLists.txt
@@ -1,2 +1 @@
 add_subdirectory(MHALToGPU)
-add_subdirectory(MHALToCPU)
diff --git a/external/mlir-hal/lib/Conversion/MHALToCPU/CMakeLists.txt b/external/mlir-hal/lib/Conversion/MHALToCPU/CMakeLists.txt
deleted file mode 100644
index 716a7f28440c..000000000000
--- a/external/mlir-hal/lib/Conversion/MHALToCPU/CMakeLists.txt
+++ /dev/null
@@ -1,17 +0,0 @@
-add_mlir_conversion_library(MLIRMHALToCPU
-  MHALToCPU.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MHAL_MAIN_INCLUDE_DIR}/mlir/Conversion/MHALToCPU
-
-  DEPENDS
-  MHALConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRMHAL
-  MLIRLLVMDialect
-  MLIRTransforms
-  )
diff --git a/external/mlir-hal/lib/Conversion/MHALToCPU/MHALToCPU.cpp b/external/mlir-hal/lib/Conversion/MHALToCPU/MHALToCPU.cpp
deleted file mode 100644
index b5d625692e47..000000000000
--- a/external/mlir-hal/lib/Conversion/MHALToCPU/MHALToCPU.cpp
+++ /dev/null
@@ -1,113 +0,0 @@
-//===- MHALToCPU.cpp - Convert MHAL to CPU dialect --------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/MHALToCPU/MHALToCPU.h"
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MHAL/IR/MHAL.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "convert-mhal-to-cpu"
-
-namespace mlir {
-#define GEN_PASS_DEF_CONVERTMHALTOCPUPASS
-#include "mlir/Conversion/MHALPasses.h.inc"
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::mhal;
-
-//===----------------------------------------------------------------------===//
-// Convert MHAL dialect types to CPU types.
-//===----------------------------------------------------------------------===//
-
-//===----------------------------------------------------------------------===//
-// Convert mhal.launch ops with 'cpu' target to cpu.launch_func ops with
-// required memory staging.
-//===----------------------------------------------------------------------===//
-
-namespace {
-// Helper to pull out the called func
-static std::optional<func::FuncOp> getCalledFunc(mhal::LaunchOp op) {
-  CallOpInterface callIf(op);
-  if (auto *callable = callIf.resolveCallable()) {
-    if (auto func = dyn_cast<func::FuncOp>(callable))
-      return func;
-  }
-
-  return std::nullopt;
-}
-
-struct LaunchRewritePattern : public OpRewritePattern<mhal::LaunchOp> {
-  using OpRewritePattern<mhal::LaunchOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhal::LaunchOp op,
-                                PatternRewriter &rw) const override {
-    Location loc = op.getLoc();
-
-    assert(op->getNumResults() == 1); // only 1 mhal.token
-
-    if (auto func = getCalledFunc(op)) {
-      // Replace the original `mhal.launch` with a call to the outlined
-      // function.
-      func::CallOp::create(rw, loc, *func, op.getArgOperands());
-
-      Value empty;
-      op->replaceAllUsesWith(ValueRange(empty));
-      op->erase();
-
-      return success();
-    }
-    return rw.notifyMatchFailure(op, "func not found");
-  }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Remove all mhal.await ops
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct AwaitRewritePattern : public OpRewritePattern<mhal::AwaitOp> {
-  using OpRewritePattern<mhal::AwaitOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhal::AwaitOp op,
-                                PatternRewriter &rw) const override {
-    rw.eraseOp(op);
-    return success();
-  }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct ConvertMHALToCPUPass
-    : public impl::ConvertMHALToCPUPassBase<ConvertMHALToCPUPass> {
-  void runOnOperation() override;
-};
-} // namespace
-
-void ConvertMHALToCPUPass::runOnOperation() {
-  auto op = getOperation();
-  MLIRContext *ctx = op->getContext();
-
-  // Convert mhal.launch to func.call ops, remove all mhal.await ops
-  RewritePatternSet patterns(ctx);
-  patterns.add<LaunchRewritePattern>(ctx);
-  patterns.add<AwaitRewritePattern>(ctx);
-
-  if (failed(applyPatternsGreedily(op, std::move(patterns))))
-    signalPassFailure();
-
-  op.walk([](func::FuncOp f) { f->removeAttr("mhal.targets"); });
-}
diff --git a/external/mlir-hal/lib/Conversion/MHALToGPU/MHALToGPU.cpp b/external/mlir-hal/lib/Conversion/MHALToGPU/MHALToGPU.cpp
index 7e90eecc1458..66f4196541c2 100644
--- a/external/mlir-hal/lib/Conversion/MHALToGPU/MHALToGPU.cpp
+++ b/external/mlir-hal/lib/Conversion/MHALToGPU/MHALToGPU.cpp
@@ -13,11 +13,11 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MHAL/IR/MHAL.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 
 #define DEBUG_TYPE "convert-mhal-to-gpu"
 
@@ -27,44 +27,18 @@ namespace mlir {
 } // namespace mlir
 
 using namespace mlir;
-using namespace mlir::mhal;
 
 //===----------------------------------------------------------------------===//
-// Convert MHAL dialect types to GPU types.
+// Lower bufferized host calls to GPU kernels (mhal.targets).
 //===----------------------------------------------------------------------===//
 
-namespace {
-/// MHALGPUTypeConverter only converts types from the MHAL dialect to
-/// the corresponding GPU type and does not convert any other types.
-class MHALGPUTypeConverter : public TypeConverter {
-public:
-  MHALGPUTypeConverter() {
-    addConversion([](Type type) { return type; });
-    addConversion([](TokenType type) {
-      return gpu::AsyncTokenType::get(type.getContext());
-    });
-  }
-};
-} // namespace
-
-// Helper to pull out the called func
-static std::optional<func::FuncOp> getCalledFunc(mhal::LaunchOp op) {
-  CallOpInterface callIf(op);
-  if (auto *callable = callIf.resolveCallable()) {
-    if (auto func = dyn_cast<func::FuncOp>(callable))
-      return func;
-  }
-
-  return std::nullopt;
-}
-
-// Get target{gpu} attribute from called func
-static std::optional<mhal::KernelPackageAttr> getGPUTarget(mhal::LaunchOp op) {
-  auto func = getCalledFunc(op);
-  if (!func.has_value() || func->getNumResults() != 0)
+/// mhal.targets[gpu] lookup for a kernel function symbol (bufferized
+/// func.call).
+static std::optional<mhal::KernelPackageAttr> getGPUTarget(func::FuncOp func) {
+  if (func.getNumResults() != 0)
     return std::nullopt;
 
-  auto attr = (*func)->template getAttrOfType<ArrayAttr>("mhal.targets");
+  auto attr = func->getAttrOfType<ArrayAttr>("mhal.targets");
   if (!attr)
     return std::nullopt;
@@ -76,246 +50,181 @@ static std::optional<mhal::KernelPackageAttr> getGPUTarget(mhal::LaunchOp op) {
   return std::nullopt;
 }
 
-//===----------------------------------------------------------------------===//
-// Convert mhal.launch ops with 'gpu' target to gpu.launch_func ops with
-// required memory staging.
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct LaunchRewritePattern : public OpRewritePattern<mhal::LaunchOp> {
-  using OpRewritePattern<mhal::LaunchOp>::OpRewritePattern;
-
-  Value makeWait(OpBuilder b, Location loc, ArrayRef<Value> deps = {}) const {
-    auto tokenType = b.getType<gpu::AsyncTokenType>();
-    return gpu::WaitOp::create(b, loc, tokenType, deps).getAsyncToken();
+/// Lower a bufferized func.call to a GPU kernel (mhal.targets) to
+/// gpu.launch_func with staging; synchronize with gpu.wait and erase the
+/// call.
+static LogicalResult lowerKernelCallToGpu(PatternRewriter &rw, func::CallOp op,
+                                          func::FuncOp func) {
+  Location loc = op.getLoc();
+  auto module = op->getParentOfType<ModuleOp>();
+  MLIRContext *ctx = module.getContext();
+
+  auto kernelPkg = getGPUTarget(func);
+  if (!kernelPkg.has_value())
+    return rw.notifyMatchFailure(op, "no gpu target");
+
+  auto targetObj = kernelPkg->getObject();
+  auto binary = targetObj.getBinary();
+  auto launchDims = kernelPkg->getLaunchDims();
+  if (launchDims.size() != 2)
+    return rw.notifyMatchFailure(op, "bad launch dims");
+  auto gridSize = launchDims[0];
+  auto blockSize = launchDims[1];
+
+  FunctionOpInterface funcIF(func);
+  auto funcName = funcIF.getName();
+  std::string binaryName = (funcName + "_module").str();
+
+  auto binaryOp = module.lookupSymbol<gpu::BinaryOp>(binaryName);
+  if (!binaryOp) {
+    OpBuilder b(ctx);
+    binaryOp = gpu::BinaryOp::create(b, loc, binaryName, nullptr,
+                                     ArrayRef<Attribute>({binary}));
+
+    SymbolTable symbolTable(module);
+    symbolTable.insert(binaryOp);
   }
 
-  template <typename T> bool isOnDevice(const T &oprUsers) const {
-    for (auto opUse : oprUsers) {
-      auto gpuLaunch = dyn_cast<gpu::LaunchFuncOp>(opUse);
-      auto launch = dyn_cast<mhal::LaunchOp>(opUse);
-      // assumes the same GPU
-      if (!gpuLaunch && !(launch && getGPUTarget(launch).has_value()))
-        return false;
+  auto makeWait = [&](OpBuilder &b, Location l, ArrayRef<Value> deps) {
+    auto tt = b.getType<gpu::AsyncTokenType>();
+    return gpu::WaitOp::create(b, l, tt, deps).getAsyncToken();
+  };
+
+  auto userOnDevice = [&](Operation *userOp) {
+    if (isa<gpu::LaunchFuncOp>(userOp))
+      return true;
+    if (auto call = dyn_cast<func::CallOp>(userOp)) {
+      if (auto callee = module.lookupSymbol<func::FuncOp>(call.getCallee()))
+        return getGPUTarget(callee).has_value();
     }
-    return true;
-  }
+    return false;
+  };
 
-  Value moveMemory(OpBuilder b, mhal::LaunchOp launchOp, Value opr,
-                   uint32_t fidx, bool writeAccess,
-                   llvm::SmallVector<Value> &copyBackOprs,
-                   llvm::SmallVector<Value> &asyncDeps) const {
+  auto moveMemory = [&](Operation *anchor, Value opr, uint32_t fidx,
+                        bool writeAccess,
+                        llvm::SmallVector<Value> &copyBackOprs,
+                        llvm::SmallVector<Value> &asyncDeps) -> Value {
     if (auto gpuAllocOp = opr.getDefiningOp<gpu::AllocOp>()) {
-      // TEST: convergence or multi-input??
-      assert(isOnDevice(opr.getUsers()));
+      assert(llvm::all_of(opr.getUsers(), userOnDevice));
       asyncDeps.push_back(gpuAllocOp.getAsyncToken());
       return opr;
    }
-
-    Location loc = opr.getLoc();
+    Location oloc = opr.getLoc();
+    OpBuilder b = rw;
     auto tokenType = b.getType<gpu::AsyncTokenType>();
     auto oprAllocOp = opr.getDefiningOp<memref::AllocOp>();
-    auto bAlloc = b;
+    OpBuilder bAlloc = b;
     if (oprAllocOp)
       bAlloc.setInsertionPointAfter(oprAllocOp);
-
-    Value allocWait = makeWait(bAlloc, loc);
-    Type gpuMemType = opr.getType();
+    Value allocWait = makeWait(bAlloc, oloc, {});
     auto dst =
-        gpu::AllocOp::create(bAlloc, loc, gpuMemType, tokenType,
+        gpu::AllocOp::create(bAlloc, oloc, opr.getType(), tokenType,
                              ValueRange{allocWait}, ValueRange{}, ValueRange{});
     Value dstMem = dst.getResult(0);
     Value dstToken = dst.getResult(1);
-
-    auto makeCopy = [&]() {
-      // always copy to device, even if it's read_access only
-      // this way we initialize with whatever was provided by the user
-      auto memcpyToken = gpu::MemcpyOp::create(
-          b, loc, tokenType, ValueRange{dstToken}, dstMem, opr);
-      dstToken = memcpyToken.getResult(0);
-      if (writeAccess) {
-        // copy from device
+    auto runCopy = [&] {
+      dstToken = gpu::MemcpyOp::create(b, oloc, tokenType, ValueRange{dstToken},
+                                       dstMem, opr)
+                     .getResult(0);
+      if (writeAccess)
         copyBackOprs[fidx] = oprAllocOp ? opr : dstMem;
-      }
     };
 
     if (oprAllocOp) {
-      // if alloc, convert to gpu.alloc and gpu.memcpy's
-      SmallVector<Operation *> oprUsers(opr.getUsers());
-      if (isOnDevice(oprUsers)) {
+      if (llvm::all_of(opr.getUsers(), userOnDevice)) {
         opr.replaceAllUsesWith(dstMem);
       } else {
-        // substitute
-        launchOp->replaceUsesOfWith(opr, dstMem);
-        makeCopy();
+        anchor->replaceUsesOfWith(opr, dstMem);
+        runCopy();
       }
-    } else
-      makeCopy();
-
+    } else {
+      runCopy();
+    }
     asyncDeps.push_back(dstToken);
     return dstMem;
-  }
-
-  LogicalResult matchAndRewrite(mhal::LaunchOp op,
-                                PatternRewriter &rw) const override {
-    Location loc = op.getLoc();
-    auto caller = op->getParentOfType<func::FuncOp>();
-    auto module = caller->getParentOfType<ModuleOp>();
-    auto *ctx = module.getContext();
-
-    assert(op->getNumResults() == 1); // only 1 mhal.token
-
-    // 1. get target{gpu} attribute from func
-
-    auto kernelPkg = getGPUTarget(op);
-    if (!kernelPkg.has_value())
-      return rw.notifyMatchFailure(op, "no gpu target");
-
-    auto targetObj = kernelPkg->getObject();
-    auto binary = targetObj.getBinary();
-    auto launchDims = kernelPkg->getLaunchDims();
-    if (launchDims.size() != 2)
-      return rw.notifyMatchFailure(op, "bad launch dims");
-    auto gridSize = launchDims[0];
-    auto blockSize = launchDims[1];
-
-    auto func = *getCalledFunc(op);
-    Location floc = func.getLoc();
-
-    // 2. re-materialize gpu.binary @_module [#gpu.object<...>]
-
-    FunctionOpInterface funcIF(func);
-    auto funcName = funcIF.getName();
-    std::string binaryName = (funcName + "_module").str();
-
-    auto binaryOp = module.lookupSymbol<gpu::BinaryOp>(binaryName);
-    if (!binaryOp) {
-      OpBuilder b(ctx);
-      binaryOp = gpu::BinaryOp::create(b, floc, binaryName, nullptr,
-                                       ArrayRef<Attribute>({binary}));
-
-      SymbolTable symbolTable(module);
-      symbolTable.insert(binaryOp);
-    }
-
-    // 3. create substitute gpu.launch_func
-    //   %15 = gpu.wait async
-    //   %16 = gpu.launch_func async [%15] @test_fusion_module::@test_fusion
-    //   blocks in (%c900, %c1, %c1) threads in (%c256, %c1, %c1)
-    //   dynamic_shared_memory_size %c0_i32 args(%4 : memref<128x32x32x8xf32>,
-    //   %9 : memref<128x3x3x8xf32>, %14 : memref<128x30x30x128xf32>)
-
-    auto tokenType = rw.getType<gpu::AsyncTokenType>();
-
-    Value oneIdx = rw.createOrFold<arith::ConstantIndexOp>(loc, 1);
-    Value blockSizeIdx =
-        rw.createOrFold<arith::ConstantIndexOp>(loc, blockSize);
-    Value gridSizeIdx = rw.createOrFold<arith::ConstantIndexOp>(loc, gridSize);
-    Value dynamicSharedMemorySize;
-
-    // async dependencies
-    auto operands = op->getOperands();
-    llvm::SmallVector<Value> asyncDeps;
-    llvm::SmallVector<Value> gpuOperands;
-    size_t diff = operands.size() - func.getNumArguments();
-    size_t i = 0;
-    if (diff > 0) {
-      for (; i < diff; ++i)
-        asyncDeps.push_back(operands[i]);
-    } else
-      assert(diff == 0);
-
-    SmallVector<Value> copyBackOprs(func.getNumArguments(), Value());
-    for (; i < operands.size(); ++i) {
-      auto fidx = i - diff;
-      Value opr = operands[i];
-      // move input memories to GPU
-      if (isa<MemRefType>(opr.getType())) {
-        bool writeAccess{
-            func.getArgAttr(fidx, mhal::MHALDialect::getWriteAccessAttrName())};
-        opr =
-            moveMemory(rw, op, opr, fidx, writeAccess, copyBackOprs, asyncDeps);
-      }
-      gpuOperands.push_back(opr);
-    }
+  };
+
+  auto tokenType = rw.getType<gpu::AsyncTokenType>();
+  Value oneIdx = rw.createOrFold<arith::ConstantIndexOp>(loc, 1);
+  Value blockSizeIdx = rw.createOrFold<arith::ConstantIndexOp>(loc, blockSize);
+  Value gridSizeIdx = rw.createOrFold<arith::ConstantIndexOp>(loc, gridSize);
+  Value dynamicSharedMemorySize;
+
+  auto operands = op->getOperands();
+  llvm::SmallVector<Value> asyncDeps;
+  llvm::SmallVector<Value> gpuOperands;
+
+  SmallVector<Value> copyBackOprs(func.getNumArguments(), Value());
+  for (size_t i = 0; i < operands.size(); ++i) {
+    Value opr = operands[i];
+    if (isa<MemRefType>(opr.getType())) {
+      bool writeAccess{
+          func.getArgAttr(i, mhal::MHALDialect::getWriteAccessAttrName())};
+      opr = moveMemory(op, opr, i, writeAccess, copyBackOprs, asyncDeps);
+    }
+    gpuOperands.push_back(opr);
+  }
 
-    // The gpu.launch_func requires 1 and only 1 token
-    if (asyncDeps.empty())
-      // There must be at least 1 token
-      asyncDeps.push_back(makeWait(rw, loc));
-    else if (asyncDeps.size() > 1) {
-      // Consolidate to 1 token
-      auto launchWait = makeWait(rw, loc, asyncDeps);
-      asyncDeps = {launchWait};
+  if (asyncDeps.empty())
+    asyncDeps.push_back(makeWait(rw, loc, {}));
+  else if (asyncDeps.size() > 1)
+    asyncDeps = {makeWait(rw, loc, asyncDeps)};
+
+  auto gpuLaunchOp = gpu::LaunchFuncOp::create(
+      rw, loc,
+      SymbolRefAttr::get(ctx, binaryName,
+                         {FlatSymbolRefAttr::get(ctx, funcName)}),
+      gpu::KernelDim3{gridSizeIdx, oneIdx, oneIdx},
+      gpu::KernelDim3{blockSizeIdx, oneIdx, oneIdx}, dynamicSharedMemorySize,
+      gpuOperands, tokenType, ValueRange(asyncDeps));
+  Value token = gpuLaunchOp->getResult(0);
+
+  // Insert gpu.memcpy for results
+  SmallVector<Value> tokens;
+  for (auto pair : llvm::enumerate(copyBackOprs)) {
+    if (auto gpuMem = pair.value()) {
+      auto dst = operands[pair.index()];
+      if (gpuMem.getDefiningOp<memref::AllocOp>())
+        std::swap(gpuMem, dst);
+      auto memcpy = gpu::MemcpyOp::create(rw, loc, tokenType, ValueRange{token},
+                                          dst, gpuMem);
+      tokens.push_back(memcpy.getResult(0));
     }
   }
 
-    // Make gpu.launch_func
-    auto gpuLaunchOp = gpu::LaunchFuncOp::create(
-        rw, loc,
-        SymbolRefAttr::get(getContext(), binaryName,
-                           {FlatSymbolRefAttr::get(getContext(), funcName)}),
-        gpu::KernelDim3{gridSizeIdx, oneIdx, oneIdx},
-        gpu::KernelDim3{blockSizeIdx, oneIdx, oneIdx}, dynamicSharedMemorySize,
-        gpuOperands, tokenType, ValueRange(asyncDeps));
-    Value token = gpuLaunchOp->getResult(0);
-
-    // Insert gpu.memcpy for results
-    SmallVector<Value> tokens;
-    for (auto pair : llvm::enumerate(copyBackOprs)) {
-      if (auto gpuMem = pair.value()) {
-        auto dst = operands[diff + pair.index()];
-        if (gpuMem.getDefiningOp<memref::AllocOp>())
-          std::swap(gpuMem, dst);
-        auto memcpy = gpu::MemcpyOp::create(rw, loc, tokenType,
-                                            ValueRange{token}, dst, gpuMem);
-        tokens.push_back(memcpy.getResult(0));
-      }
-    }
-
-    // Consolidate tokens for replacement of mhal.launch
-    if (tokens.size() > 1) {
-      // insert gpu.wait
-      token = makeWait(rw, loc, tokens);
-    } else if (tokens.size() == 1)
-      token = tokens[0];
-
-    rw.replaceOp(op, {token});
-
-    module->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
-                    rw.getUnitAttr());
+  if (tokens.size() > 1)
+    token = makeWait(rw, loc, tokens);
+  else if (tokens.size() == 1)
+    token = tokens[0];
 
-    return success();
-  }
-};
-} // namespace
+  gpu::WaitOp::create(rw, loc, Type(), token);
+  rw.eraseOp(op);
 
-//===----------------------------------------------------------------------===//
-// Convert mhal.await to the corresponding GPU API call.
-//===----------------------------------------------------------------------===//
+  module->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
+                  rw.getUnitAttr());
+  return success();
+}
 
 namespace {
-struct AwaitRewritePattern : public OpRewritePattern<mhal::AwaitOp> {
-  using OpRewritePattern<mhal::AwaitOp>::OpRewritePattern;
+/// Bufferized func.call to a kernel with mhal.targets (e.g.
+/// clone-harness).
+struct KernelFuncCallRewritePattern : public OpRewritePattern<func::CallOp> {
+  using OpRewritePattern<func::CallOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(mhal::AwaitOp op,
+  LogicalResult matchAndRewrite(func::CallOp op,
                                 PatternRewriter &rw) const override {
-    auto tokenType = rw.getType<gpu::AsyncTokenType>();
-    Value input = op->getOperand(0);
-    if (input.getType() == tokenType) {
-      // mhal.await with token type should never have a result type
-      assert(op.getResultType() == std::nullopt);
-      gpu::WaitOp::create(rw, op.getLoc(), Type(), input);
-      rw.eraseOp(op);
-      return success();
-    }
-
-    return rw.notifyMatchFailure(op, "no gpu token");
+    if (op.getNumResults() != 0)
+      return rw.notifyMatchFailure(op,
+                                   "expected bufferized call (zero results)");
+    auto func = op->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(
+        op.getCallee());
+    if (!func || !getGPUTarget(func).has_value())
+      return rw.notifyMatchFailure(op, "callee has no mhal.targets[gpu]");
+    assert(op->getNumOperands() ==
+           static_cast<size_t>(func.getNumArguments()));
+    return lowerKernelCallToGpu(rw, op, func);
   }
 };
 } // namespace
 
-//===----------------------------------------------------------------------===//
-
 namespace {
 struct ConvertMHALToGPUPass
     : public impl::ConvertMHALToGPUPassBase<ConvertMHALToGPUPass> {
@@ -327,23 +236,11 @@ void ConvertMHALToGPUPass::runOnOperation() {
   auto op = getOperation();
   MLIRContext *ctx = op->getContext();
 
-  {
-    // Convert mhal.launch to gpu.launch if mhal.targets[gpu] exists
-    RewritePatternSet patterns(ctx);
-    patterns.add<LaunchRewritePattern>(ctx);
+  RewritePatternSet patterns(ctx);
+  patterns.add<KernelFuncCallRewritePattern>(ctx);
 
-    if (failed(applyPatternsGreedily(op, std::move(patterns))))
-      signalPassFailure();
-  }
-
-  {
-    // Convert mhal.await to gpu.wait if has gpu.tokens
-    RewritePatternSet patterns(ctx);
-    patterns.add<AwaitRewritePattern>(ctx);
-
-    if (failed(applyPatternsGreedily(op, std::move(patterns))))
-      signalPassFailure();
-  }
+  if (failed(applyPatternsGreedily(op, std::move(patterns))))
+    signalPassFailure();
 
   op.walk([](func::FuncOp f) { f->removeAttr("mhal.targets"); });
 }
diff --git a/external/mlir-hal/lib/Dialect/MHAL/IR/CMakeLists.txt b/external/mlir-hal/lib/Dialect/MHAL/IR/CMakeLists.txt
index febab930070f..66a7e9bc684a 100644
--- a/external/mlir-hal/lib/Dialect/MHAL/IR/CMakeLists.txt
+++ b/external/mlir-hal/lib/Dialect/MHAL/IR/CMakeLists.txt
@@ -1,7 +1,6 @@
 add_mlir_dialect_library(MLIRMHAL
   MHAL.cpp
-  MHALOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MHAL_MAIN_INCLUDE_DIR}/mlir/Dialect/MHAL
diff --git a/external/mlir-hal/lib/Dialect/MHAL/IR/MHAL.cpp b/external/mlir-hal/lib/Dialect/MHAL/IR/MHAL.cpp
index d320833fdca4..62f4cc958d1c 100644
--- a/external/mlir-hal/lib/Dialect/MHAL/IR/MHAL.cpp
+++ b/external/mlir-hal/lib/Dialect/MHAL/IR/MHAL.cpp
@@ -1,4 +1,4 @@
-//===- MHAL.cpp - MHAL MLIR Operations -----------------------------===//
+//===- MHAL.cpp - MHAL MLIR Dialect ----------------------------------===//
 //
 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -25,15 +25,6 @@ void mhal::MHALDialect::initialize() {
 #define GET_ATTRDEF_LIST
 #include "mlir/Dialect/MHAL/IR/MHALAttrDefs.cpp.inc"
       >();
-  addOperations<
-#define GET_OP_LIST
-#include "mlir/Dialect/MHAL/IR/MHALOps.cpp.inc"
-      >();
-  addTypes<
-#define GET_TYPEDEF_LIST
-#include "mlir/Dialect/MHAL/IR/MHALOpsTypes.cpp.inc"
-      >();
-  // addInterfaces();
 }
 
 //===----------------------------------------------------------------------===//
@@ -255,10 +246,9 @@ void KernelPackageAttr::print(mlir::AsmPrinter &printer) const {
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
-// TableGen'd op method definitions
+// TableGen'd enum definitions
 //===----------------------------------------------------------------------===//
 
-#define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/MHAL/IR/MHALTypes.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/external/mlir-hal/lib/Dialect/MHAL/IR/MHALOps.cpp b/external/mlir-hal/lib/Dialect/MHAL/IR/MHALOps.cpp
deleted file mode 100644
index 48808eb806be..000000000000
--- a/external/mlir-hal/lib/Dialect/MHAL/IR/MHALOps.cpp
+++ /dev/null
@@ -1,202 +0,0 @@
-//===- MHAL.cpp - MLIR MHAL Operations ----------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/MHAL/IR/MHAL.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/Interfaces/FunctionImplementation.h"
-#include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/TypeSwitch.h"
-
-using namespace mlir;
-using namespace mlir::mhal;
-
-// constexpr StringRef MHALDialect::kAllowedToBlockAttrName;
-
-#include "mlir/Dialect/MHAL/IR/MHALOps.cpp.inc"
-#include "mlir/Dialect/MHAL/IR/MHALOpsTypes.cpp.inc"
-
-//===----------------------------------------------------------------------===//
-/// LaunchOp
-//===----------------------------------------------------------------------===//
-
-void LaunchOp::build(OpBuilder &builder, OperationState &result,
-                     func::FuncOp func, ValueRange dependencies,
-                     ValueRange operands) {
-  // set callee
-  result.addAttribute(getCalleeAttrName(result.name), SymbolRefAttr::get(func));
-
-  result.addOperands(dependencies);
-  result.addOperands(operands);
-
-  // Add derived `operand_segment_sizes` attribute based on parsed operands.
-  int32_t numDependencies = dependencies.size();
-  int32_t numOperands = operands.size();
-  auto operandSegmentSizes =
-      builder.getDenseI32ArrayAttr({numDependencies, numOperands});
-  result.addAttribute(getOperandSegmentSizesAttrName(result.name),
-                      operandSegmentSizes);
-
-  // First result is always a token, and then `resultTypes` wrapped into
-  // `mhal.value`.
-  result.addTypes({TokenType::get(result.getContext())});
-  for (Type type : func.getResultTypes())
-    result.addTypes(type);
-}
-
-/// Return the callee of this operation.
-CallInterfaceCallable LaunchOp::getCallableForCallee() {
-  return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
-}
-
-/// Set the callee for this operation.
-void LaunchOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-  (*this)->setAttr("callee", cast<SymbolRefAttr>(callee));
-}
-
-/// Return the operands passed to the callee.
-MutableOperandRange LaunchOp::getArgOperandsMutable() {
-  return getLaunchOperandsMutable();
-}
-
-/// Return the operands passed to the callee.
-Operation::operand_range LaunchOp::getArgOperands() {
-  return getLaunchOperands();
-}
-
-/// Return the callee results.
-Operation::result_range LaunchOp::getCallResults() {
-  return {++result_begin(), result_end()};
-}
-
-/// Return the callee result types.
-Operation::result_type_range LaunchOp::getCallResultTypes() {
-  return getResults();
-}
-
-/// Recompute the operand_segment_sizes attribute.
-void LaunchOp::updateSegmentSizes(MLIRContext *ctx) {
-  auto tokenTy = TokenType::get(ctx);
-  int32_t numDependencies = 0;
-  int32_t numOperands = 0;
-  for (const auto &oper : getOperands()) {
-    if (oper.getType() == tokenTy) {
-      // All tokens should come first.
-      assert(numOperands == 0);
-      numDependencies++;
-    } else
-      numOperands++;
-  }
-
-  auto operandSegmentSizes =
-      DenseI32ArrayAttr::get(ctx, {numDependencies, numOperands});
-  (*this)->setAttr(getOperandSegmentSizesAttrName(), operandSegmentSizes);
-
-  assert(!(*this)->hasAttr("result_segment_sizes"));
-}
-
-LogicalResult LaunchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto callable = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
-  if (!callable)
-    return emitOpError("requires a 'callee' symbol reference attribute");
-
-  func::FuncOp func =
-      symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, callable);
-  if (!func)
-    return emitOpError() << "'" << callable.getValue()
-                         << "' does not reference a valid function";
-
-  auto funcResultTypes = func.getResultTypes();
-  // The result types should be a leading mhal.token and matching return types
-  // of the func.
-  auto resultTypes = getResultTypes();
-  if (resultTypes.size() != (funcResultTypes.size() + 1))
-    return emitOpError(
-        "requires matching result types with a leading mhal.token");
-
-  auto resultItr = ++resultTypes.begin();
-  for (auto resType : funcResultTypes) {
-    if (*resultItr++ != resType)
-      return emitOpError("requires matching result types with func");
-  }
-
-  // Match operand types
-  auto funcArgumentTypes = func.getArgumentTypes();
-  if (funcArgumentTypes.size() != getLaunchOperands().size())
-    return emitOpError("incorrect number of operands for callee");
-
-  for (auto tuple : llvm::zip(getLaunchOperands(), funcArgumentTypes)) {
-    if (std::get<0>(tuple).getType() != std::get<1>(tuple))
-      return emitOpError("requires matching operand types");
-  }
-
-  return success();
-}
-
-LogicalResult LaunchOp::verify() {
-  MLIRContext *ctx = getContext();
-  auto tokenTy = TokenType::get(ctx);
-
-  // The dependencies must be mhal.tokens
-  for (auto dep : getDependencies()) {
-    if (dep.getType() != tokenTy)
-      return emitOpError("requires all dependencies to be mhal.token");
-  }
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-/// AwaitOp
-//===----------------------------------------------------------------------===//
-
-void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
-                    ArrayRef<NamedAttribute> attrs) {
-  result.addOperands({operand});
-  result.attributes.append(attrs.begin(), attrs.end());
-}
-
-static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
-                                        Type &resultType) {
-  if (parser.parseType(operandType))
-    return failure();
-
-  return success();
-}
-
-static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
-                                 Type operandType, Type resultType) {
-  p << operandType;
-}
-
-LogicalResult AwaitOp::verify() {
-  Type argType = getOperand().getType();
-
-  // Awaiting on a token does not have any results.
-  if (isa<TokenType>(argType) && !getResultTypes().empty())
-    return emitOpError("awaiting on a token must have empty result");
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// TableGen'd op method definitions
-//===----------------------------------------------------------------------===//
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/MHAL/IR/MHALOps.cpp.inc"
-
-//===----------------------------------------------------------------------===//
-// TableGen'd type method definitions
-//===----------------------------------------------------------------------===//
-
-#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/MHAL/IR/MHALOpsTypes.cpp.inc"
diff --git a/external/mlir-hal/lib/Dialect/MHAL/Pipelines/CMakeLists.txt b/external/mlir-hal/lib/Dialect/MHAL/Pipelines/CMakeLists.txt
index 7bb4886685bb..d5901d376664 100644
--- a/external/mlir-hal/lib/Dialect/MHAL/Pipelines/CMakeLists.txt
+++ b/external/mlir-hal/lib/Dialect/MHAL/Pipelines/CMakeLists.txt
@@ -9,7 +9,6 @@ if (MHAL_ENABLE_HOST_RUNNER)
     MLIRVectorToLLVMPass
 
     MLIRMHALToGPU
-    MLIRMHALToCPU
     MLIRAffineToStandard
     MLIRSCFToControlFlow
   )
diff --git a/external/mlir-hal/lib/Dialect/MHAL/Pipelines/Pipelines.cpp b/external/mlir-hal/lib/Dialect/MHAL/Pipelines/Pipelines.cpp
index 5d7fb7969ee0..e3a4ea0bdad7 100644
--- a/external/mlir-hal/lib/Dialect/MHAL/Pipelines/Pipelines.cpp
+++ b/external/mlir-hal/lib/Dialect/MHAL/Pipelines/Pipelines.cpp
@@ -43,7 +43,6 @@
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
-#include "mlir/Conversion/MHALToCPU/MHALToCPU.h"
 #include "mlir/Conversion/MHALToGPU/MHALToGPU.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
@@ -67,9 +66,9 @@ void mhal::buildPackagePipeline(OpPassManager &pm,
   pm.addPass(mhal::createMHALPackageTargetsPass());
 }
 
-// Runner takes an Affine/SCF program with mhal retargetable launches
-// and lowers to host LLVM runtime program. JitRunner then calls ORC
-// to generate X86 binary and runs it.
+// Runner takes an Affine/SCF program with bufferized func.call kernel
+// invocations and lowers to host LLVM runtime program. JitRunner then calls
+// ORC to generate X86 binary and runs it.
 void mhal::buildRunnerPipeline(OpPassManager &pm,
                                const mhal::RunnerOptions &options) {
 #ifdef MHAL_ENABLE_HOST_RUNNER
@@ -93,10 +92,8 @@ void mhal::buildRunnerPipeline(OpPassManager &pm,
   // Make gpu ops async if they didn't come from the async world
   pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
 
-  // Target mhal.launch to gpu.launch_func
+  // Lower bufferized GPU kernel func.call to gpu.launch_func
   pm.addPass(createConvertMHALToGPUPass());
-  // Target remaining mhal.launch to cpu.call
-  pm.addPass(createConvertMHALToCPUPass());
 
   pm.addPass(createAsyncParallelForPass());
   auto &funcPm2 = pm.nest<func::FuncOp>();
diff --git a/external/mlir-hal/lib/Dialect/MHAL/Transforms/BufferizableOpInterfaceImpl.cpp b/external/mlir-hal/lib/Dialect/MHAL/Transforms/BufferizableOpInterfaceImpl.cpp
deleted file mode 100644
index 87b04950c455..000000000000
--- a/external/mlir-hal/lib/Dialect/MHAL/Transforms/BufferizableOpInterfaceImpl.cpp
+++ /dev/null
@@ -1,194 +0,0 @@
-//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/MHAL/Transforms/BufferizableOpInterfaceImpl.h"
-
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
-#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MHAL/IR/MHAL.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Interfaces/CallInterfaces.h"
-
-using namespace mlir;
-using namespace mlir::bufferization;
-using namespace mlir::bufferization::func_ext;
-using namespace mlir::mhal;
-using namespace mlir::func;
-
-namespace mlir {
-namespace mhal {
-namespace {
-
-/// Return the FuncOp called by `callOp`.
-static FuncOp getCalledFunction(CallOpInterface callOp) {
-  SymbolRefAttr sym = dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
-  if (!sym)
-    return nullptr;
-  return dyn_cast_or_null<FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
-}
-
-/// Get FuncAnalysisState.
-static const FuncAnalysisState &
-getFuncAnalysisState(const AnalysisState &state) {
-  assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
-  auto *result = static_cast<const OneShotAnalysisState &>(state)
-                     .getExtension<FuncAnalysisState>();
-  assert(result && "FuncAnalysisState does not exist");
-  return *result;
-}
-
-/// Return the state (phase) of analysis of the FuncOp.
-static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
-                                                  FuncOp funcOp) {
-  if (!isa<OneShotAnalysisState>(state))
-    return FuncOpAnalysisState::NotAnalyzed;
-  auto *funcState = static_cast<const OneShotAnalysisState &>(state)
-                        .getExtension<FuncAnalysisState>();
-  if (!funcState)
-    return FuncOpAnalysisState::NotAnalyzed;
-  const auto &analyzedFuncOps = funcState->analyzedFuncOps;
-  auto it = analyzedFuncOps.find(funcOp);
-  if (it == analyzedFuncOps.end())
-    return FuncOpAnalysisState::NotAnalyzed;
-  return it->second;
-}
-
-/// Bufferization of mhal.launch.
-struct LaunchOpInterface
-    : public BufferizableOpInterface::ExternalModel<LaunchOpInterface,
-                                                    mhal::LaunchOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    mhal::LaunchOp launchOp = cast<mhal::LaunchOp>(op);
-    auto opOperandIdx =
-        opOperand.getOperandNumber() - launchOp.getDependencies().size();
-    mlir::CallOpInterface callOp(op);
-    FuncOp funcOp = getCalledFunction(callOp);
-    assert(funcOp && "expected CallOp to a FuncOp");
-
-    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
-      // FuncOp not analyzed yet. Assume that OpOperand is read.
-      return true;
-
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
-    return funcState.readBbArgs.lookup(funcOp).contains(opOperandIdx);
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    // operands are always inputs
-    return false;
-  }
-
-  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
-                                      const AnalysisState &state) const {
-    // results never alias with operands
-    AliasingValueList result;
-    return result;
-  }
-
-  /// All function arguments are writable. It is the responsibility of the
-  /// CallOp to insert buffer copies where necessary.
-  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
-    mlir::CallOpInterface callOp(op);
-    auto callOperands = callOp.getArgOperands();
-    auto callResultTypes = callOp.getCallResultTypes();
-    unsigned numOperands = callOp->getNumOperands();
-    FuncOp funcOp = getCalledFunction(callOp);
-    assert(funcOp && "expected CallOp to a FuncOp");
-
-    // Result types of the bufferized CallOp.
-    SmallVector<Type> resultTypes;
-
-    // Operands of the bufferized CallOp.
-    SmallVector<Value> newOperands(numOperands, Value());
-
-    // 1. Compute the result types of the new CallOp.
-    unsigned funcResultIdx = 0;
-    for (const auto &it : llvm::enumerate(callOp->getResults())) {
-      auto returnVal = it.value();
-      Type returnType = returnVal.getType();
-      if (isa<TensorType>(returnType)) {
-        assert(returnType == callResultTypes[funcResultIdx++]);
-        FailureOr<BufferLikeType> bufferType =
-            bufferization::getBufferType(returnVal, options, state);
-        if (failed(bufferType))
-          return failure();
-        assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
-        BaseMemRefType memrefType = cast<BaseMemRefType>(*bufferType);
-        resultTypes.push_back(memrefType);
-      } else {
-        // Non-tensor values are returned.
-        resultTypes.push_back(returnType);
-        if (returnType == callResultTypes[funcResultIdx])
-          funcResultIdx++;
-      }
-    }
-
-    // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
-    unsigned funcOperandIdx = 0;
-    for (OpOperand &opOperand : callOp->getOpOperands()) {
-      unsigned idx = opOperand.getOperandNumber();
-      Value tensorOperand = opOperand.get();
-      // Non-tensor operands are just copied.
-      if (isa<mhal::TokenType>(tensorOperand.getType())) {
-        newOperands[idx] = tensorOperand;
-        continue;
-      }
-      if (!isa<TensorType>(tensorOperand.getType())) {
-        newOperands[idx] = tensorOperand;
-        if (tensorOperand == callOperands[funcOperandIdx])
-          funcOperandIdx++;
-        continue;
-      }
-
-      // Retrieve buffers for tensor operands.
-      Value buffer = newOperands[idx];
-      if (!buffer) {
-        FailureOr<Value> maybeBuffer =
-            getBuffer(rewriter, opOperand.get(), options, state);
-        if (failed(maybeBuffer))
-          return failure();
-        buffer = *maybeBuffer;
-      }
-      newOperands[idx] = buffer;
-      funcOperandIdx++;
-    }
-
-    // 3. Create the new CallOp.
-    Operation *newCallOp =
-        callOp.clone(rewriter, callOp.getLoc(), resultTypes, newOperands);
-
-    // 4. Replace the old op with the new op.
-    replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults());
-
-    return success();
-  }
-};
-
-} // namespace
-} // namespace mhal
-} // namespace mlir
-
-void mlir::mhal::registerBufferizableOpInterfaceExternalModels(
-    DialectRegistry &registry) {
-  registry.addExtension(+[](MLIRContext *ctx, mhal::MHALDialect *dialect) {
-    mhal::LaunchOp::attachInterface<LaunchOpInterface>(*ctx);
-  });
-}
diff --git a/external/mlir-hal/lib/Dialect/MHAL/Transforms/Bufferize.cpp b/external/mlir-hal/lib/Dialect/MHAL/Transforms/Bufferize.cpp
index 1e461e2196e4..3491d09222ba 100644
--- a/external/mlir-hal/lib/Dialect/MHAL/Transforms/Bufferize.cpp
+++ b/external/mlir-hal/lib/Dialect/MHAL/Transforms/Bufferize.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MHAL/IR/MHAL.h"
-#include "mlir/Dialect/MHAL/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
 
@@ -43,7 +42,6 @@ struct MHALBufferizePass
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<bufferization::BufferizationDialect,
                     memref::MemRefDialect>();
-    mhal::registerBufferizableOpInterfaceExternalModels(registry);
   }
 };
 } // namespace
diff --git a/external/mlir-hal/lib/Dialect/MHAL/Transforms/CMakeLists.txt b/external/mlir-hal/lib/Dialect/MHAL/Transforms/CMakeLists.txt
index bd20f381b941..bc31eb296962 100644
--- a/external/mlir-hal/lib/Dialect/MHAL/Transforms/CMakeLists.txt
+++ b/external/mlir-hal/lib/Dialect/MHAL/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_dialect_library(MLIRMHALTransforms
   Bufferize.cpp
-  BufferizableOpInterfaceImpl.cpp
  DropMetadata.cpp
   EmulateNarrowType.cpp
   PackageTargets.cpp
diff --git a/external/mlir-hal/lib/Dialect/MHAL/Transforms/EmulateNarrowType.cpp b/external/mlir-hal/lib/Dialect/MHAL/Transforms/EmulateNarrowType.cpp
index 02e55d9ae445..01b4bc863215 100644
--- a/external/mlir-hal/lib/Dialect/MHAL/Transforms/EmulateNarrowType.cpp
+++ b/external/mlir-hal/lib/Dialect/MHAL/Transforms/EmulateNarrowType.cpp
@@ -85,25 +85,6 @@ class ExtractStridedMetadataFromOldFuncArgs
     return success();
   }
 };
-
-struct MHalLaunchOpRewritePattern : public OpConversionPattern<LaunchOp> {
-  using OpConversionPattern<LaunchOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(LaunchOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Type> newReturnTypes;
-    if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
-                                                newReturnTypes)))
-      return rewriter.notifyMatchFailure(
-          op, "failed to convert result type of launched function");
-    rewriter.replaceOpWithNewOp<LaunchOp>(
-        op, newReturnTypes.front(), ArrayRef<Type>(newReturnTypes).drop_front(),
-        adaptor.getCallee(), adaptor.getDependencies(),
-        adaptor.getLaunchOperands());
-    return success();
-  }
-};
 } // end namespace
 
 void mlir::mhal::populateMHalNarrowTypeEmulationConversions(
@@ -124,13 +105,6 @@ void mlir::mhal::populateMHalNarrowTypeEmulationConversions(
   typeConverter.addTargetMaterialization(materializer);
 }
 
-void mlir::mhal::populateMHalNarrowTypeEmulationBoundaryPatterns(
-    arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns.add<MHalLaunchOpRewritePattern>(typeConverter,
-                                           patterns.getContext());
-}
-
 void mlir::mhal::populateMHalNarrowTypeEmulationPatterns(
     arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
@@ -157,11 +131,6 @@ void MHalEmulateNarrowTypePass::runOnOperation() {
   boundaryTarget.addDynamicallyLegalOp<func::FuncOp>(
       [&typeConverter](func::FuncOp op) {
        return typeConverter.isLegal(op.getFunctionType());
       });
-  boundaryTarget.addDynamicallyLegalOp<mhal::LaunchOp>(
-      [&typeConverter](mhal::LaunchOp op) {
-        return typeConverter.isLegal(op.getCallResultTypes()) &&
-               typeConverter.isLegal(op.getOperandTypes());
-      });
   boundaryTarget.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
       opLegalCallback);
 
@@ -175,8 +144,6 @@ void MHalEmulateNarrowTypePass::runOnOperation() {
   RewritePatternSet boundaryPatterns(ctx);
   arith::populateArithNarrowTypeEmulationPatterns(typeConverter,
                                                   boundaryPatterns);
-  mhal::populateMHalNarrowTypeEmulationBoundaryPatterns(typeConverter,
-                                                        boundaryPatterns);
   if (failed(applyPartialConversion(op, boundaryTarget,
                                     std::move(boundaryPatterns))))
     return signalPassFailure();
diff --git a/mlir/include/mlir/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.h b/mlir/include/mlir/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.h
index 27f0d396bd0d..4b04fc85ba12 100644
--- a/mlir/include/mlir/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.h
+++ b/mlir/include/mlir/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.h
@@ -13,8 +13,6 @@
 #ifndef MLIR_CONVERSION_MIGRAPHXTOLINALG_H
 #define MLIR_CONVERSION_MIGRAPHXTOLINALG_H
 
-#include "mlir/Dialect/MIGraphX/IR/MIGraphX.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -40,9 +38,6 @@ void populateMIGraphXToLinalgBoundaryDialectConversion(
 void populateMIGraphXFuncBoundaryToLinalgConversionPatterns(
     RewritePatternSet &target, TypeConverter &typeConverter);
 
-/// Populates conversion patterns for function boundaries mhal.launcher
-void populateMIGraphXToLinalgMHALLauncherConversion(
-    RewritePatternSet &target, TypeConverter &typeConverter);
 } // namespace migraphx
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h b/mlir/include/mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h
index 2c2912a1d083..4560a61707fe 100644
--- a/mlir/include/mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h
+++ b/mlir/include/mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h
@@ -16,12 +16,8 @@
 #ifndef MLIR_CONVERSION_MIGRAPHXTOTOSA_H
 #define MLIR_CONVERSION_MIGRAPHXTOTOSA_H
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/MHAL/IR/MHAL.h"
-#include "mlir/Dialect/MIGraphX/IR/MIGraphX.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Conversion/RocMLIRPasses.td b/mlir/include/mlir/Conversion/RocMLIRPasses.td
index aa9458d18a89..5801d4d350a8 100644
--- a/mlir/include/mlir/Conversion/RocMLIRPasses.td
+++ b/mlir/include/mlir/Conversion/RocMLIRPasses.td
@@ -124,11 +124,7 @@ def MIGraphXToTosaPass : Pass<"migraphx-to-tosa", "::mlir::func::FuncOp"> {
     Pass that converts MIGraphX operations to TOSA operations.
}]; - let dependentDialects = [ - "func::FuncDialect", - "tosa::TosaDialect", - "mhal::MHALDialect", - ]; + let dependentDialects = ["func::FuncDialect", "tosa::TosaDialect"]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp index 5fe0361bf9f8..3a7ebd3b6f0d 100644 --- a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp +++ b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp @@ -10,7 +10,6 @@ // //===----------------------------------------------------------------------===// #include "mlir/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.h" -#include "mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -2191,10 +2190,6 @@ void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns( RewritePatternSet &patterns, TypeConverter &typeConverter) { patterns.add( typeConverter, patterns.getContext()); - - // mhal.launch can be generated through rocmlir-gen, so we need a way to - // legalize it - populateMIGraphXToLinalgMHALLauncherConversion(patterns, typeConverter); populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); populateReturnOpTypeConversionPattern(patterns, typeConverter); populateCallOpTypeConversionPattern(patterns, typeConverter); diff --git a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalgPass.cpp b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalgPass.cpp index abb44768253b..ab084bcf7e8b 100644 --- a/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalgPass.cpp +++ b/mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalgPass.cpp @@ -15,11 +15,10 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MIGraphX/IR/MIGraphX.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Rock/IR/Rock.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -53,11 +52,6 @@ void mlir::migraphx::populateMIGraphXToLinalgBoundaryDialectConversion( target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { return typeConverter.isSignatureLegal(op.getFunctionType()); }); - target.addDynamicallyLegalOp<mhal::LaunchOp>( - [=](mhal::LaunchOp op) -> std::optional<bool> { - return typeConverter.isLegal(op.getResultTypes()) && - typeConverter.isLegal(op.getOperandTypes()); - }); target.addDynamicallyLegalOp<func::ReturnOp>( [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); target.addDynamicallyLegalOp( @@ -97,8 +91,6 @@ void MIGraphXToLinalgPass::runOnOperation() { boundaryConversionTarget, boundaryTypeConverter); migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns( boundaryPattern, boundaryTypeConverter); - migraphx::populateMIGraphXToLinalgMHALLauncherConversion( - boundaryPattern, boundaryTypeConverter); if (failed(applyPartialConversion(func, boundaryConversionTarget, std::move(boundaryPattern)))) { return signalPassFailure(); diff --git a/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp b/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp index e1b19efcdf86..a01843f37647 100644 --- a/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp +++ b/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp @@ -17,7 +17,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h"
"mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MHAL/IR/MHAL.h" #include "mlir/Dialect/MIGraphX/IR/MIGraphX.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Rock/IR/Rock.h" @@ -32,7 +31,6 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/APFloat.h" #include "llvm/Support/Debug.h" @@ -1560,14 +1558,6 @@ struct AsUnderlyingShapeConverter final ConversionPatternRewriter &rewriter) const final; }; -/// This mirrors the call op conversion pattern but works for mhal.launch. -struct MHALLaunchConverter final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mhal::LaunchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final; -}; } // namespace LogicalResult AsLogicalShapeConverter::matchAndRewrite( @@ -1709,26 +1699,6 @@ LogicalResult AsUnderlyingShapeConverter::matchAndRewrite( return success(); } -LogicalResult MHALLaunchConverter::matchAndRewrite( - mhal::LaunchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // Convert the original function results. - SmallVector resultTypes; - if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) - return failure(); - - // If this isn't a one-to-one type mapping, we don't know how to aggregate - // the results. - if (op->getNumResults() != resultTypes.size()) - return failure(); - - // Substitute with the new result types from the corresponding FuncType - // conversion. - rewriter.replaceOpWithNewOp( - op, op.getCalleeAttr(), resultTypes, adaptor.getOperands()); - return success(); -} - //===----------------------------------------------------------------------===// // External interface //===----------------------------------------------------------------------===// @@ -1765,13 +1735,9 @@ void migraphx::populateMIGraphXToTosaConversionPatterns( void mlir::migraphx::populateMIGraphXFuncBoundaryToTosaConversionPatterns( RewritePatternSet &patterns, TypeConverter &typeConverter) { patterns.add, - MHALLaunchConverter>(typeConverter, patterns.getContext()); + TrivialConverter>( + typeConverter, patterns.getContext()); // Add upstream patterns that take care of func.func and its friends. 
populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); populateCallOpTypeConversionPattern(patterns, typeConverter); } -void mlir::migraphx::populateMIGraphXToLinalgMHALLauncherConversion( - RewritePatternSet &patterns, TypeConverter &typeConverter) { - patterns.add<MHALLaunchConverter>(typeConverter, patterns.getContext()); -} diff --git a/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosaPass.cpp b/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosaPass.cpp index 76573dbcedf4..fde9fd3d72bc 100644 --- a/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosaPass.cpp +++ b/mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosaPass.cpp @@ -14,15 +14,11 @@ #include "mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MIGraphX/IR/MIGraphX.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" #define DEBUG_TYPE "migraphx-to-tosa" @@ -62,11 +58,6 @@ void mlir::migraphx::populateMIGraphXFuncBoundaryToTosaDialectConversion( [=](func::CallOp op) -> std::optional<bool> { return typeConverter->isSignatureLegal(op.getCalleeType()); }); - target.addDynamicallyLegalOp<mhal::LaunchOp>( - [=](mhal::LaunchOp op) -> std::optional<bool> { - return typeConverter->isLegal(op.getResultTypes()) && - typeConverter->isLegal(op.getOperandTypes()); - }); target.addDynamicallyLegalOp<func::ReturnOp>( [=](func::ReturnOp op) -> std::optional<bool> { return typeConverter->isLegal(op); diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp index 5349f4447c55..8d1402eecf06 100644 --- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp +++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp @@ -62,7 +62,7 @@ void rock::buildBufferizePipeline(OpPassManager &pm, bool noRock = options.disableRock; auto &funcPm = pm.nest<func::FuncOp>(); - // TOSA conversion to rock and/or linalg with mhal.launch's + // TOSA conversion to rock and/or linalg with func.call if (!noRock) { // convert tosa.conv2d/matmul to rock.conv /* rocmlir-opt --tosa-to-tensor --tosa-to-rock --rock-view-to-transform diff --git a/mlir/lib/Dialect/Rock/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Rock/Transforms/EmulateNarrowType.cpp index 23967735b23a..afd7d12d3cff 100644 --- a/mlir/lib/Dialect/Rock/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/EmulateNarrowType.cpp @@ -499,8 +499,6 @@ void RockEmulateNarrowTypePass::runOnOperation() { RewritePatternSet boundaryPatterns(ctx); arith::populateArithNarrowTypeEmulationPatterns(typeConverter, boundaryPatterns); - mhal::populateMHalNarrowTypeEmulationBoundaryPatterns(typeConverter, - boundaryPatterns); if (failed(applyPartialConversion(op, boundaryTarget, std::move(boundaryPatterns)))) return signalPassFailure(); diff --git a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp index 1ba83bdde8ff..40c3239628f3 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp @@ -21,6 +21,7 @@ //===-----------------------------------------------------===// #include "mlir/Analysis/BufferDependencyAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MHAL/IR/MHAL.h"
"mlir/Dialect/MHAL/IR/MHAL.h" #include "mlir/Dialect/Rock/IR/AmdArchDb.h" #include "mlir/Dialect/Rock/IR/GemmSize.h" diff --git a/mlir/test/Dialect/MHAL/emulate-narrow-type.mlir b/mlir/test/Dialect/MHAL/emulate-narrow-type.mlir index 448d7a40405e..0749e6b14108 100644 --- a/mlir/test/Dialect/MHAL/emulate-narrow-type.mlir +++ b/mlir/test/Dialect/MHAL/emulate-narrow-type.mlir @@ -19,13 +19,11 @@ func.func @foo(%arg0: memref<8xi4>) -> memref<8xi4> { // CHECK-LABEL: func.func @foo_wrapper // CHECK-SAME: (%[[ARG0:.+]]: memref<4xi8>) // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8> -// CHECK-NEXT: %[[TOKEN:.+]], %[[V:.+]] = mhal.launch @foo (%[[ALLOC]]) : (memref<4xi8>) -> memref<4xi8> -// CHECK-NEXT: mhal.await %[[TOKEN]] +// CHECK-NEXT: %[[V:.+]] = call @foo(%[[ALLOC]]) : (memref<4xi8>) -> memref<4xi8> // CHECK-NEXT: memref.copy %[[V]], %[[ARG0]] func.func @foo_wrapper(%arg0: memref<8xi4>) { %alloc = memref.alloc() : memref<8xi4> - %token, %v = mhal.launch @foo(%alloc) : (memref<8xi4>) -> memref<8xi4> - mhal.await %token : !mhal.token + %v = func.call @foo(%alloc) : (memref<8xi4>) -> memref<8xi4> memref.copy %v, %arg0 : memref<8xi4> to memref<8xi4> return } diff --git a/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case1.mlir b/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case1.mlir index 0c5563e2acaa..d91856be03f8 100644 --- a/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case1.mlir +++ b/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case1.mlir @@ -23,8 +23,7 @@ module { } func.func @test_reduce(%arg0: memref<5x4x3xf32>, %arg1: memref<1x4x3xf32> {mhal.read_access, mhal.write_access}) attributes {rock.arch = ""} { call @init_output (%arg1) : (memref<1x4x3xf32>) -> () - %token1 = mhal.launch @test_reduce__part_1 (%arg0, %arg1) : (memref<5x4x3xf32>, memref<1x4x3xf32>) - mhal.await %token1 : !mhal.token + func.call @test_reduce__part_1(%arg0, %arg1) : (memref<5x4x3xf32>, memref<1x4x3xf32>) -> () return } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case2.mlir b/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case2.mlir index d464ed0895d4..237bce1a4561 100644 --- a/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case2.mlir +++ b/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case2.mlir @@ -23,8 +23,7 @@ module { } func.func @test_reduce(%arg0: memref<10x30x20xf32>, %arg1: memref<10x1x20xf32> {mhal.read_access, mhal.write_access}) attributes {rock.arch = ""} { call @init_output (%arg1) : (memref<10x1x20xf32>) -> () - %token1 = mhal.launch @test_reduce__part_1 (%arg0, %arg1) : (memref<10x30x20xf32>, memref<10x1x20xf32>) - mhal.await %token1 : !mhal.token + func.call @test_reduce__part_1(%arg0, %arg1) : (memref<10x30x20xf32>, memref<10x1x20xf32>) -> () return } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##",mhal.module} { diff --git a/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case3.mlir b/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case3.mlir index 75477d14a73c..5dc40f206e0f 100644 --- a/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case3.mlir +++ b/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case3.mlir @@ -23,8 +23,7 @@ module { } func.func @test_reduce(%arg0: memref<20x30x10xf32>, %arg1: memref<1x30x10xf32> 
{mhal.read_access, mhal.write_access}) attributes {rock.arch = ""} { call @init_output (%arg1) : (memref<1x30x10xf32>) -> () - %token1 = mhal.launch @test_reduce__part_1 (%arg0, %arg1) : (memref<20x30x10xf32>, memref<1x30x10xf32>) - mhal.await %token1 : !mhal.token + func.call @test_reduce__part_1(%arg0, %arg1) : (memref<20x30x10xf32>, memref<1x30x10xf32>) -> () return } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case4.mlir b/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case4.mlir index 287bfbbe610e..f71fd291d460 100644 --- a/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case4.mlir +++ b/mlir/test/Dialect/Rock/integration/reduce/reduce_max/rock-reduce-max-case4.mlir @@ -23,8 +23,7 @@ module { } func.func @test_reduce(%arg0: memref<1000x250x100xf32>, %arg1: memref<1x250x100xf32> {mhal.read_access, mhal.write_access}) attributes {rock.arch = ""} { call @init_output (%arg1) : (memref<1x250x100xf32>) -> () - %token1 = mhal.launch @test_reduce__part_1 (%arg0, %arg1) : (memref<1000x250x100xf32>, memref<1x250x100xf32>) - mhal.await %token1 : !mhal.token + func.call @test_reduce__part_1(%arg0, %arg1) : (memref<1000x250x100xf32>, memref<1x250x100xf32>) -> () return } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case1.mlir b/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case1.mlir index 8b65aa6bc836..d20a4cf1b537 100644 --- a/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case1.mlir +++ b/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case1.mlir @@ -36,8 +36,7 @@ module { func.func @test_reduce(%arg0: memref<2x3x40xf32>, %arg1: memref<2x3x1xf32>) attributes {rock.arch = ""} { call @zero_init (%arg1) : (memref<2x3x1xf32>) -> () - %token1 = mhal.launch @test_reduce__part_1 (%arg0, %arg1) : (memref<2x3x40xf32>, memref<2x3x1xf32>) - mhal.await %token1 : !mhal.token + func.call @test_reduce__part_1(%arg0, %arg1) : (memref<2x3x40xf32>, memref<2x3x1xf32>) -> () return } diff --git a/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case2.mlir b/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case2.mlir index 151a0685788f..88d4ef8729cc 100644 --- a/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case2.mlir +++ b/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case2.mlir @@ -23,8 +23,7 @@ module { } func.func @test_reduce(%arg0: memref<10x30x20xf32>, %arg1: memref<10x1x20xf32>) attributes {rock.arch = ""} { call @zero_init (%arg1) : (memref<10x1x20xf32>) -> () - %token1 = mhal.launch @test_reduce__part_1 (%arg0, %arg1) : (memref<10x30x20xf32>, memref<10x1x20xf32>) - mhal.await %token1 : !mhal.token + func.call @test_reduce__part_1(%arg0, %arg1) : (memref<10x30x20xf32>, memref<10x1x20xf32>) -> () return } module @__xmodule_gfx90a attributes {mhal.arch = "gfx90a",mhal.module} { diff --git a/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case3.mlir b/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case3.mlir index c6ff957af073..90155290579c 100644 --- a/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case3.mlir +++ b/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case3.mlir 
@@ -23,8 +23,7 @@ module { } func.func @test_reduce(%arg0: memref<20x30x10xf32>, %arg1: memref<1x30x10xf32>) attributes {rock.arch = ""} { call @zero_init (%arg1) : (memref<1x30x10xf32>) -> () - %token1 = mhal.launch @test_reduce__part_1 (%arg0, %arg1) : (memref<20x30x10xf32>, memref<1x30x10xf32>) - mhal.await %token1 : !mhal.token + func.call @test_reduce__part_1(%arg0, %arg1) : (memref<20x30x10xf32>, memref<1x30x10xf32>) -> () return } module @__xmodule_gfx90a attributes {mhal.arch = "gfx90a", mhal.module} { diff --git a/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case4.mlir b/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case4.mlir index d75c374c42b9..f3f00609aa33 100644 --- a/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case4.mlir +++ b/mlir/test/Dialect/Rock/integration/reduce/reduce_sum/gfx90a/rock-reduce-case4.mlir @@ -23,8 +23,7 @@ module { } func.func @test_reduce(%arg0: memref<1000x250x100xf32>, %arg1: memref<1x250x100xf32>) attributes {rock.arch = ""} { call @zero_init (%arg1) : (memref<1x250x100xf32>) -> () - %token1 = mhal.launch @test_reduce__part_1 (%arg0, %arg1) : (memref<1000x250x100xf32>, memref<1x250x100xf32>) - mhal.await %token1 : !mhal.token + func.call @test_reduce__part_1(%arg0, %arg1) : (memref<1000x250x100xf32>, memref<1x250x100xf32>) -> () return } module @__xmodule_gfx90a attributes {mhal.arch = "gfx90a", mhal.module} { diff --git a/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f16/mixr-attention-tier1-f16-case1.mlir b/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f16/mixr-attention-tier1-f16-case1.mlir index 3ca058ba3d32..d128cf9368bd 100644 --- a/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f16/mixr-attention-tier1-f16-case1.mlir +++ b/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f16/mixr-attention-tier1-f16-case1.mlir @@ -8,8 +8,7 @@ module { return %2 : !migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1> } func.func @mlir_attention_wrapper(%arg0: !migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1>, %arg1: !migraphx.shaped<1x12x64x384xf16, 294912x24576x384x1>, %arg2: !migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1>) -> !migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1> { - %token, %results = mhal.launch @mlir_attention (%arg0, %arg1, %arg2) : (!migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1>, !migraphx.shaped<1x12x64x384xf16, 294912x24576x384x1>, !migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1>) -> !migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1> - mhal.await %token : !mhal.token + %results = func.call @mlir_attention(%arg0, %arg1, %arg2) : (!migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1>, !migraphx.shaped<1x12x64x384xf16, 294912x24576x384x1>, !migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1>) -> !migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1> return %results : !migraphx.shaped<1x12x384x64xf16, 294912x24576x64x1> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f16/mixr-attention-tier1-f16-case2.mlir b/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f16/mixr-attention-tier1-f16-case2.mlir index 352994b08e2a..8f289748e731 100644 --- a/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f16/mixr-attention-tier1-f16-case2.mlir +++ b/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f16/mixr-attention-tier1-f16-case2.mlir @@ -8,8 +8,7 @@ module { return %2 : !migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1> } func.func 
@mlir_attention_wrapper(%arg0: !migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1>, %arg1: !migraphx.shaped<2x16x64x384xf16, 393216x24576x384x1>, %arg2: !migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1>) -> !migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1> { - %token, %results = mhal.launch @mlir_attention (%arg0, %arg1, %arg2) : (!migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1>, !migraphx.shaped<2x16x64x384xf16, 393216x24576x384x1>, !migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1>) -> !migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1> - mhal.await %token : !mhal.token + %results = func.call @mlir_attention(%arg0, %arg1, %arg2) : (!migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1>, !migraphx.shaped<2x16x64x384xf16, 393216x24576x384x1>, !migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1>) -> !migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1> return %results : !migraphx.shaped<2x16x384x64xf16, 393216x24576x64x1> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f32/mixr-attention-tier1-f32-case1.mlir b/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f32/mixr-attention-tier1-f32-case1.mlir index 3ce8eeafa86a..0a80a651feaf 100644 --- a/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f32/mixr-attention-tier1-f32-case1.mlir +++ b/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f32/mixr-attention-tier1-f32-case1.mlir @@ -8,8 +8,7 @@ module { return %2 : !migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1> } func.func @mlir_attention_wrapper(%arg0: !migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1>, %arg1: !migraphx.shaped<1x12x64x384xf32, 294912x24576x384x1>, %arg2: !migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1>) -> !migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1> { - %token, %results = mhal.launch @mlir_attention (%arg0, %arg1, %arg2) : (!migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1>, !migraphx.shaped<1x12x64x384xf32, 294912x24576x384x1>, !migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1>) -> !migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1> - mhal.await %token : !mhal.token + %results = func.call @mlir_attention(%arg0, %arg1, %arg2) : (!migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1>, !migraphx.shaped<1x12x64x384xf32, 294912x24576x384x1>, !migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1>) -> !migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1> return %results : !migraphx.shaped<1x12x384x64xf32, 294912x24576x64x1> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f32/mixr-attention-tier1-f32-case2.mlir b/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f32/mixr-attention-tier1-f32-case2.mlir index 380d2aae0bfb..49e1e52d538f 100644 --- a/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f32/mixr-attention-tier1-f32-case2.mlir +++ b/mlir/test/fusion/nightly-misc-e2e/mixr-attention/f32/mixr-attention-tier1-f32-case2.mlir @@ -8,8 +8,7 @@ module { return %2 : !migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1> } func.func @mlir_attention_wrapper(%arg0: !migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1>, %arg1: !migraphx.shaped<2x16x64x384xf32, 393216x24576x384x1>, %arg2: !migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1>) -> !migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1> { - %token, %results = mhal.launch @mlir_attention (%arg0, %arg1, %arg2) : (!migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1>, !migraphx.shaped<2x16x64x384xf32, 393216x24576x384x1>, 
!migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1>) -> !migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1> - mhal.await %token : !mhal.token + %results = func.call @mlir_attention(%arg0, %arg1, %arg2) : (!migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1>, !migraphx.shaped<2x16x64x384xf32, 393216x24576x384x1>, !migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1>) -> !migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1> return %results : !migraphx.shaped<2x16x384x64xf32, 393216x24576x64x1> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/mixr-conv-bias-clipped-relu.mlir b/mlir/test/fusion/pr-e2e/mixr-conv-bias-clipped-relu.mlir index f49bff60b5bb..f93e8b760b11 100644 --- a/mlir/test/fusion/pr-e2e/mixr-conv-bias-clipped-relu.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-conv-bias-clipped-relu.mlir @@ -13,8 +13,7 @@ module { return %7 : !migraphx.shaped<4x4x1x1xf32, 4x1x1x1> } func.func @mlir_convolution_add_clip_wrapper(%arg0: !migraphx.shaped<1x4x1x1xf32, 4x1x1x1>, %arg1: !migraphx.shaped<4x3x3x3xf32, 27x9x3x1>, %arg2: !migraphx.shaped<4x3x3x3xf32, 27x9x3x1>) -> !migraphx.shaped<4x4x1x1xf32, 4x1x1x1> { - %token, %results = mhal.launch @mlir_convolution_add_clip (%arg0, %arg1, %arg2) : (!migraphx.shaped<1x4x1x1xf32, 4x1x1x1>, !migraphx.shaped<4x3x3x3xf32, 27x9x3x1>, !migraphx.shaped<4x3x3x3xf32, 27x9x3x1>) -> !migraphx.shaped<4x4x1x1xf32, 4x1x1x1> - mhal.await %token : !mhal.token + %results = func.call @mlir_convolution_add_clip(%arg0, %arg1, %arg2) : (!migraphx.shaped<1x4x1x1xf32, 4x1x1x1>, !migraphx.shaped<4x3x3x3xf32, 27x9x3x1>, !migraphx.shaped<4x3x3x3xf32, 27x9x3x1>) -> !migraphx.shaped<4x4x1x1xf32, 4x1x1x1> return %results : !migraphx.shaped<4x4x1x1xf32, 4x1x1x1> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/mixr-dot-int4-f16.mlir b/mlir/test/fusion/pr-e2e/mixr-dot-int4-f16.mlir index b1d827f29e00..4ebc5443ad26 100644 --- a/mlir/test/fusion/pr-e2e/mixr-dot-int4-f16.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-dot-int4-f16.mlir @@ -1,5 +1,5 @@ // RUN: rocmlir-driver -kernel-pipeline=migraphx %s | rocmlir-gen -fut mlir_unpack_dequantizelinear_dot --arch %arch --clone-harness - | FileCheck %s --check-prefix=HASINT4 -// HASINT4: mhal.launch +// HASINT4: call @mlir_unpack_dequantizelinear_dot // HASINT4-SAME: tensor<64xi4> // RUN: rocmlir-driver -kernel-pipeline=migraphx %s | rocmlir-gen -fut mlir_unpack_dequantizelinear_dot --arch %arch --clone-harness - | rocmlir-driver -host-pipeline=highlevel -kernel-pipeline=highlevel | rocmlir-gen -ph -fut mlir_unpack_dequantizelinear_dot_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s diff --git a/mlir/test/fusion/pr-e2e/mixr-expand-strides-non-multiple.mlir b/mlir/test/fusion/pr-e2e/mixr-expand-strides-non-multiple.mlir index e3e23f8dadad..d8e3ce8fbd18 100644 --- a/mlir/test/fusion/pr-e2e/mixr-expand-strides-non-multiple.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-expand-strides-non-multiple.mlir @@ -1,10 +1,9 @@ -// RUN: rocmlir-gen -fut mlir_dot_sigmoid --arch %arch --clone-harness 
%s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_dot_sigmoid_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut mlir_dot_sigmoid --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_dot_sigmoid_wrapper --verifier clone -print-verify-results=always - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s -// Only ~41% of the results will be correct since the non-contiguous strides -// in this example is the result of concating 4x5x24 and 4x7x24 (i.e.m 4x12x24). -// This kernel is specificall dealing with making sure that the 4x5x24 elements -// are all in the correct place in the larger tensor. -// CHECK: relDiff = 0 : 480/1152 (41.666667%) +// Non-contiguous strides from a concat-like layout (4x5x24 and 4x7x24 in a +// 4x12x24 logical tensor). Checks that the dot+sigmoid kernel places the +// 4x5x24 slice correctly under expand-strides; with correct GPU staging, the +// GPU result matches the CPU reference on the full output.
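+// (1152 = 4 x 12 x 24, so the relDiff = 0 bucket below covers every element
+// of the backing buffer, not just the 480-element 4x5x24 slice.)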
+// CHECK: relDiff = 0 : 1152/1152 (100.000000%) module { func.func @mlir_dot_sigmoid(%arg0: !migraphx.shaped<4x5x16xf16, 80x16x1>, %arg1: !migraphx.shaped<4x16x24xf16, 384x24x1>) -> !migraphx.shaped<4x5x24xf16, 288x24x1> attributes {rock.arch = "gfx1201", rock.kernel = "mixr", rock.num_cu = 32 : i64} { diff --git a/mlir/test/fusion/pr-e2e/mixr-gemm/mixr-gemm-tr-folding.mlir b/mlir/test/fusion/pr-e2e/mixr-gemm/mixr-gemm-tr-folding.mlir index 16e8b05bf86a..21358673a983 100644 --- a/mlir/test/fusion/pr-e2e/mixr-gemm/mixr-gemm-tr-folding.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-gemm/mixr-gemm-tr-folding.mlir @@ -9,8 +9,7 @@ module { return %3 : !migraphx.shaped<2x16x8xf32, 128x8x1> } func.func @mlir_transpose_reshape_dot_wrapper(%arg0: !migraphx.shaped<2x8x4x4xf32, 128x16x4x1>, %arg1: !migraphx.shaped<1x8x8xf32, 64x8x1> ) -> !migraphx.shaped<2x16x8xf32, 128x8x1> { - %token, %results = mhal.launch @mlir_transpose_reshape_dot (%arg0, %arg1) : (!migraphx.shaped<2x8x4x4xf32, 128x16x4x1>, !migraphx.shaped<1x8x8xf32, 64x8x1>) -> !migraphx.shaped<2x16x8xf32, 128x8x1> - mhal.await %token : !mhal.token + %results = func.call @mlir_transpose_reshape_dot(%arg0, %arg1) : (!migraphx.shaped<2x8x4x4xf32, 128x16x4x1>, !migraphx.shaped<1x8x8xf32, 64x8x1>) -> !migraphx.shaped<2x16x8xf32, 128x8x1> return %results : !migraphx.shaped<2x16x8xf32, 128x8x1> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/mixr-gemm/mixr-gemm-tr-folding2.mlir b/mlir/test/fusion/pr-e2e/mixr-gemm/mixr-gemm-tr-folding2.mlir index 090eec88d388..dc71df76bd03 100644 --- a/mlir/test/fusion/pr-e2e/mixr-gemm/mixr-gemm-tr-folding2.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-gemm/mixr-gemm-tr-folding2.mlir @@ -8,8 +8,7 @@ module { return %2 : !migraphx.shaped<1x6xf32, 6x1> } func.func @mlir_transpose_reshape_dot_wrapper(%arg0: !migraphx.shaped<1x2x1x3xf32, 6x3x3x1>, %arg1: !migraphx.shaped<6x6xf32, 6x1>) -> !migraphx.shaped<1x6xf32, 6x1> { - %token, %results = mhal.launch @mlir_transpose_reshape_dot (%arg0, %arg1) : (!migraphx.shaped<1x2x1x3xf32, 6x3x3x1>, !migraphx.shaped<6x6xf32, 6x1>) -> !migraphx.shaped<1x6xf32, 6x1> - mhal.await %token : !mhal.token + %results = func.call @mlir_transpose_reshape_dot(%arg0, %arg1) : (!migraphx.shaped<1x2x1x3xf32, 6x3x3x1>, !migraphx.shaped<6x6xf32, 6x1>) -> !migraphx.shaped<1x6xf32, 6x1> return %results : !migraphx.shaped<1x6xf32, 6x1> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/mixr-non-contiguous-strides-sub.mlir b/mlir/test/fusion/pr-e2e/mixr-non-contiguous-strides-sub.mlir index e4029d7c3ac0..02e75bb649ed 100644 --- a/mlir/test/fusion/pr-e2e/mixr-non-contiguous-strides-sub.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-non-contiguous-strides-sub.mlir @@ -1,8 +1,8 @@ -// RUN: rocmlir-gen -fut mlir_dot_log --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_dot_log_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void 
| FileCheck %s +// RUN: rocmlir-gen -fut mlir_dot_log --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_dot_log_wrapper --verifier clone -print-verify-results=always - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s -// Only half of the results will be correct since the non-contiguous strides -// in this example means that about half of the memory is uninitialized. -// CHECK: relDiff = 0 : 2304/4608 (50.000000%) +// migraphx.sub %0, %0 yields zeros; verifier should report all elements in the +// relDiff = 0 bucket (100% of elements). +// CHECK: relDiff = 0 : 4608/4608 (100.000000%) module { func.func @mlir_dot_log(%arg0: !migraphx.shaped<4x24x16xf16, 384x16x1>, %arg1: !migraphx.shaped<4x16x24xf16, 384x24x1>) -> !migraphx.shaped<4x24x24xf16, 1152x24x1> { diff --git a/mlir/test/fusion/pr-e2e/mixr-non-contiguous-strides.mlir b/mlir/test/fusion/pr-e2e/mixr-non-contiguous-strides.mlir index cacc68f79a02..ce8a3b24191f 100644 --- a/mlir/test/fusion/pr-e2e/mixr-non-contiguous-strides.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-non-contiguous-strides.mlir @@ -1,10 +1,8 @@ -// RUN: rocmlir-gen -fut mlir_dot_sigmoid --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_dot_sigmoid_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut mlir_dot_sigmoid --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_dot_sigmoid_wrapper --verifier clone -print-verify-results=always - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s -// Only half of the results will be correct since the non-contiguous strides -// in this example means that half of the memory is uninitialized. We expect -// at least 2304/4608 correct; uninitialized positions may also coincidentally -// match, so the count can be slightly higher (at least 50% correct is expected). 
-// CHECK: relDiff = 0 : {{[0-9]+}}/4608 ({{[5-9][0-9]}}. +// Non-contiguous strides: validates expand-strides / padded layout. With +// correct GPU host staging, CPU and GPU references match on the full tensor. +// CHECK: relDiff = 0 : 4608/4608 (100.000000%) module { func.func @mlir_dot_sigmoid(%arg0: !migraphx.shaped<4x24x16xf16, 384x16x1>, %arg1: !migraphx.shaped<4x16x24xf16, 384x24x1>) -> !migraphx.shaped<4x24x24xf16, 1152x24x1> attributes {rock.kernel = "mixr"} { diff --git a/mlir/test/fusion/pr-e2e/mixr-sd-explicit-broadcasting.mlir b/mlir/test/fusion/pr-e2e/mixr-sd-explicit-broadcasting.mlir index 4db79a0bb43a..63a8816be7fd 100644 --- a/mlir/test/fusion/pr-e2e/mixr-sd-explicit-broadcasting.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-sd-explicit-broadcasting.mlir @@ -8,8 +8,7 @@ module { return %2 : !migraphx.shaped<1x1x32x32xf32, 1024x1024x32x1> } func.func @mlir_reshape_convolution_real(%arg0: !migraphx.shaped<1x1x16x1x16x1xf32, 256x256x16x16x1x1>, %arg1: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x32x32xf32, 1024x1024x32x1> { - %token, %results = mhal.launch @mlir_reshape_convolution_real__part_0 (%arg0, %arg1) : (!migraphx.shaped<1x1x16x1x16x1xf32, 256x256x16x16x1x1>, !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x32x32xf32, 1024x1024x32x1> - mhal.await %token : !mhal.token + %results = func.call @mlir_reshape_convolution_real__part_0(%arg0, %arg1) : (!migraphx.shaped<1x1x16x1x16x1xf32, 256x256x16x16x1x1>, !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x32x32xf32, 1024x1024x32x1> return %results : !migraphx.shaped<1x1x32x32xf32, 1024x1024x32x1> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-add-reduce-sum.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-add-reduce-sum.e2e.mlir index 41aeb1e1b5c7..8ea0c98469df 100644 --- a/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-add-reduce-sum.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-add-reduce-sum.e2e.mlir @@ -12,8 +12,7 @@ module { return %2 : tensor<1x128x1xf32> } func.func @dot_add(%arg0: tensor<1x128x64xf32>, %arg1: tensor<1x64x256xf32>, %arg2: tensor<1x128x256xf32>) -> tensor<1x128x1xf32> { - %token, %results = mhal.launch @dot_add__part_0 (%arg0, %arg1, %arg2) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>, tensor<1x128x256xf32>) -> tensor<1x128x1xf32> - mhal.await %token : !mhal.token + %results = func.call @dot_add__part_0(%arg0, %arg1, %arg2) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>, tensor<1x128x256xf32>) -> tensor<1x128x1xf32> return %results : tensor<1x128x1xf32> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-reduce-sum-case1.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-reduce-sum-case1.e2e.mlir index d7461964ce5f..87fe1ebc2530 100644 --- a/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-reduce-sum-case1.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-reduce-sum-case1.e2e.mlir @@ -11,8 +11,7 @@ module { return %1 : tensor<1x128x1xf32> } func.func @dot_add(%arg0: tensor<1x128x64xf32>, %arg1: tensor<1x64x256xf32>) -> tensor<1x128x1xf32> { - %token, %results = mhal.launch @dot_add__part_0 (%arg0, %arg1) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>) -> tensor<1x128x1xf32> - mhal.await %token : !mhal.token + %results = func.call @dot_add__part_0(%arg0, %arg1) : 
(tensor<1x128x64xf32>, tensor<1x64x256xf32>) -> tensor<1x128x1xf32> return %results : tensor<1x128x1xf32> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-reduce-sum-case2.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-reduce-sum-case2.e2e.mlir index 1cfddcea1796..cea6dbc6da43 100644 --- a/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-reduce-sum-case2.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/atomic_add/tosa-gemm-reduce-sum-case2.e2e.mlir @@ -11,8 +11,7 @@ module { return %1 : tensor<1x1x256xf32> } func.func @dot_add(%arg0: tensor<1x128x64xf32>, %arg1: tensor<1x64x256xf32>) -> tensor<1x1x256xf32> { - %token, %results = mhal.launch @dot_add__part_0 (%arg0, %arg1) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>) -> tensor<1x1x256xf32> - mhal.await %token : !mhal.token + %results = func.call @dot_add__part_0(%arg0, %arg1) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>) -> tensor<1x1x256xf32> return %results : tensor<1x1x256xf32> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-add-reduce-sum-bf16.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-add-reduce-sum-bf16.e2e.mlir index 22bbf68d9973..3f10a1c3d74c 100644 --- a/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-add-reduce-sum-bf16.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-add-reduce-sum-bf16.e2e.mlir @@ -12,8 +12,7 @@ module { return %2 : tensor<1x128x1xbf16> } func.func @dot_add(%arg0: tensor<1x128x64xbf16>, %arg1: tensor<1x64x256xbf16>, %arg2: tensor<1x128x256xbf16>) -> tensor<1x128x1xbf16> { - %token, %results = mhal.launch @dot_add_part_0 (%arg0, %arg1, %arg2) : (tensor<1x128x64xbf16>, tensor<1x64x256xbf16>, tensor<1x128x256xbf16>) -> tensor<1x128x1xbf16> - mhal.await %token : !mhal.token + %results = func.call @dot_add_part_0(%arg0, %arg1, %arg2) : (tensor<1x128x64xbf16>, tensor<1x64x256xbf16>, tensor<1x128x256xbf16>) -> tensor<1x128x1xbf16> return %results : tensor<1x128x1xbf16> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-reduce-sum-case1-bf16.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-reduce-sum-case1-bf16.e2e.mlir index 25731f32669d..7608ecd3ac90 100644 --- a/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-reduce-sum-case1-bf16.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-reduce-sum-case1-bf16.e2e.mlir @@ -11,8 +11,7 @@ module { return %1 : tensor<1x128x1xbf16> } func.func @dot_add(%arg0: tensor<1x128x64xbf16>, %arg1: tensor<1x64x256xbf16>) -> tensor<1x128x1xbf16> { - %token, %results = mhal.launch @dot_add__part_0 (%arg0, %arg1) : (tensor<1x128x64xbf16>, tensor<1x64x256xbf16>) -> tensor<1x128x1xbf16> - mhal.await %token : !mhal.token + %results = func.call @dot_add__part_0(%arg0, %arg1) : (tensor<1x128x64xbf16>, tensor<1x64x256xbf16>) -> tensor<1x128x1xbf16> return %results : tensor<1x128x1xbf16> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-reduce-sum-case2-bf16.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-reduce-sum-case2-bf16.e2e.mlir index 4a23396e7cd3..3bd599e97036 100644 --- 
a/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-reduce-sum-case2-bf16.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/atomic_add_bf16/tosa-gemm-reduce-sum-case2-bf16.e2e.mlir @@ -11,8 +11,7 @@ module { return %1 : tensor<1x1x256xbf16> } func.func @dot_add(%arg0: tensor<1x128x64xbf16>, %arg1: tensor<1x64x256xbf16>) -> tensor<1x1x256xbf16> { - %token, %results = mhal.launch @dot_add__part_0 (%arg0, %arg1) : (tensor<1x128x64xbf16>, tensor<1x64x256xbf16>) -> tensor<1x1x256xbf16> - mhal.await %token : !mhal.token + %results = func.call @dot_add__part_0(%arg0, %arg1) : (tensor<1x128x64xbf16>, tensor<1x64x256xbf16>) -> tensor<1x1x256xbf16> return %results : tensor<1x1x256xbf16> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-add-reduce-sum-f16.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-add-reduce-sum-f16.e2e.mlir index 3324fe2292c8..f413ddc9acc6 100644 --- a/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-add-reduce-sum-f16.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-add-reduce-sum-f16.e2e.mlir @@ -12,8 +12,7 @@ module { return %2 : tensor<1x128x1xf16> } func.func @dot_add(%arg0: tensor<1x128x64xf16>, %arg1: tensor<1x64x256xf16>, %arg2: tensor<1x128x256xf16>) -> tensor<1x128x1xf16> { - %token, %results = mhal.launch @dot_add_part_0 (%arg0, %arg1, %arg2) : (tensor<1x128x64xf16>, tensor<1x64x256xf16>, tensor<1x128x256xf16>) -> tensor<1x128x1xf16> - mhal.await %token : !mhal.token + %results = func.call @dot_add_part_0(%arg0, %arg1, %arg2) : (tensor<1x128x64xf16>, tensor<1x64x256xf16>, tensor<1x128x256xf16>) -> tensor<1x128x1xf16> return %results : tensor<1x128x1xf16> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-reduce-sum-case1-f16.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-reduce-sum-case1-f16.e2e.mlir index 87f8b14b45f0..24002819508c 100644 --- a/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-reduce-sum-case1-f16.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-reduce-sum-case1-f16.e2e.mlir @@ -11,8 +11,7 @@ module { return %1 : tensor<1x128x1xf16> } func.func @dot_add(%arg0: tensor<1x128x64xf16>, %arg1: tensor<1x64x256xf16>) -> tensor<1x128x1xf16> { - %token, %results = mhal.launch @dot_add__part_0 (%arg0, %arg1) : (tensor<1x128x64xf16>, tensor<1x64x256xf16>) -> tensor<1x128x1xf16> - mhal.await %token : !mhal.token + %results = func.call @dot_add__part_0(%arg0, %arg1) : (tensor<1x128x64xf16>, tensor<1x64x256xf16>) -> tensor<1x128x1xf16> return %results : tensor<1x128x1xf16> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-reduce-sum-case2-f16.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-reduce-sum-case2-f16.e2e.mlir index ba96892e29b3..e3afe7c42b23 100644 --- a/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-reduce-sum-case2-f16.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/atomic_add_f16/tosa-gemm-reduce-sum-case2-f16.e2e.mlir @@ -11,8 +11,7 @@ module { return %1 : tensor<1x1x256xf16> } func.func @dot_add(%arg0: tensor<1x128x64xf16>, %arg1: tensor<1x64x256xf16>) -> tensor<1x1x256xf16> { - %token, %results = mhal.launch @dot_add__part_0 (%arg0, %arg1) : (tensor<1x128x64xf16>, tensor<1x64x256xf16>) 
-> tensor<1x1x256xf16> - mhal.await %token : !mhal.token + %results = func.call @dot_add__part_0(%arg0, %arg1) : (tensor<1x128x64xf16>, tensor<1x64x256xf16>) -> tensor<1x1x256xf16> return %results : tensor<1x1x256xf16> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-add-reduce-max.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-add-reduce-max.e2e.mlir index a9616cc12652..e81d0e1f6b27 100644 --- a/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-add-reduce-max.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-add-reduce-max.e2e.mlir @@ -12,8 +12,7 @@ module { return %2 : tensor<1x128x1xf32> } func.func @dot_add(%arg0: tensor<1x128x64xf32>, %arg1: tensor<1x64x256xf32>, %arg2: tensor<1x128x256xf32>) -> tensor<1x128x1xf32> { - %token, %results = mhal.launch @dot_add__part_0 (%arg0, %arg1, %arg2) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>, tensor<1x128x256xf32>) -> tensor<1x128x1xf32> - mhal.await %token : !mhal.token + %results = func.call @dot_add__part_0(%arg0, %arg1, %arg2) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>, tensor<1x128x256xf32>) -> tensor<1x128x1xf32> return %results : tensor<1x128x1xf32> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-reduce-max-case1.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-reduce-max-case1.e2e.mlir index e294ed79933f..eba01505f2eb 100644 --- a/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-reduce-max-case1.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-reduce-max-case1.e2e.mlir @@ -11,8 +11,7 @@ module { return %1 : tensor<1x128x1xf32> } func.func @dot_add(%arg0: tensor<1x128x64xf32>, %arg1: tensor<1x64x256xf32>) -> tensor<1x128x1xf32> { - %token, %results = mhal.launch @dot_add__part_0 (%arg0, %arg1) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>) -> tensor<1x128x1xf32> - mhal.await %token : !mhal.token + %results = func.call @dot_add__part_0(%arg0, %arg1) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>) -> tensor<1x128x1xf32> return %results : tensor<1x128x1xf32> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-reduce-max-case2.e2e.mlir b/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-reduce-max-case2.e2e.mlir index 6cb61889a439..c4d7f51e5893 100644 --- a/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-reduce-max-case2.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/reductions/tosa-gemm-reduce-max-case2.e2e.mlir @@ -11,8 +11,7 @@ module { return %1 : tensor<1x1x256xf32> } func.func @dot_add(%arg0: tensor<1x128x64xf32>, %arg1: tensor<1x64x256xf32>) -> tensor<1x1x256xf32> { - %token, %results = mhal.launch @dot_add__part_0 (%arg0, %arg1) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>) -> tensor<1x1x256xf32> - mhal.await %token : !mhal.token + %results = func.call @dot_add__part_0(%arg0, %arg1) : (tensor<1x128x64xf32>, tensor<1x64x256xf32>) -> tensor<1x1x256xf32> return %results : tensor<1x1x256xf32> } module @__xmodule_ attributes {mhal.arch = "##TOKEN_ARCH##", mhal.module} { diff --git a/mlir/test/fusion/pr-e2e/tosa-to-rock-exp.e2e.mlir b/mlir/test/fusion/pr-e2e/tosa-to-rock-exp.e2e.mlir index 440e635d5588..0541e6a5d48e 100644 --- a/mlir/test/fusion/pr-e2e/tosa-to-rock-exp.e2e.mlir +++ b/mlir/test/fusion/pr-e2e/tosa-to-rock-exp.e2e.mlir @@ -1,8 +1,9 @@ -// RUN: rocmlir-gen -fut test_fusion --arch %arch --clone-harness %s | rocmlir-driver -host-pipeline highlevel 
-kernel-pipeline highlevel | rocmlir-gen -ph -fut test_fusion_wrapper -rand 1 -rand_type float --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext -entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut test_fusion --arch %arch --clone-harness %s | rocmlir-driver -host-pipeline highlevel -kernel-pipeline highlevel | rocmlir-gen -ph -fut test_fusion_wrapper -rand 1 -rand_type float --verifier clone -print-verify-results=always - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext -entry-point-result=void | FileCheck %s module { -// CHECK: RMS = {{.*}}e-09 -// CHECK: [1 0 0] +// CHECK: Number of elements: +// CHECK: RMS = {{0\.0e\+00|.*e-0[6-9]}} +// CHECK: [1 1 1] func.func @test_fusion(%arg0: tensor<128x32x32x8xf32>, %arg1: tensor<128x3x3x8xf32>) -> tensor<128x30x30x128xf32> { %zero = arith.constant dense<0.0> : tensor<128xf32> diff --git a/mlir/test/rocmlir-driver/runner-pipelines.mlir b/mlir/test/rocmlir-driver/runner-pipelines.mlir index ee1a1c965f61..35594fdfa676 100644 --- a/mlir/test/rocmlir-driver/runner-pipelines.mlir +++ b/mlir/test/rocmlir-driver/runner-pipelines.mlir @@ -10,7 +10,6 @@ // RUNNER-SAME: convert-scf-to-cf{allow-pattern-rollback=true}), // RUNNER-SAME: func.func(gpu-async-region), // RUNNER-SAME: convert-mhal-to-gpu, -// RUNNER-SAME: convert-mhal-to-cpu, // RUNNER-SAME: async-parallel-for{async-dispatch=true min-task-size=1000 num-workers=8}, // RUNNER-SAME: func.func(arith-expand{include-bf16=false include-f4e2m1=true include-f8e8m0=true include-float-min-max=true include-flush-denormals=false}, // RUNNER-SAME: convert-arith-to-llvm{index-bitwidth=0}, diff --git a/mlir/test/xmir/pr-e2e/resnet18_blk1/resnet18_blk_part0.mlir b/mlir/test/xmir/pr-e2e/resnet18_blk1/resnet18_blk_part0.mlir index 0700b6e80ace..5e87d484b6e8 100644 --- a/mlir/test/xmir/pr-e2e/resnet18_blk1/resnet18_blk_part0.mlir +++ b/mlir/test/xmir/pr-e2e/resnet18_blk1/resnet18_blk_part0.mlir @@ -1,8 +1,9 @@ -// RUN: rocmlir-gen -fut forward__part_0 --arch %arch --clone-harness %s | rocmlir-driver -host-pipeline highlevel -kernel-pipeline highlevel | rocmlir-gen -ph -fut forward__part_0_wrapper -rand 1 -rand_type float --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut forward__part_0 --arch %arch --clone-harness %s | rocmlir-driver -host-pipeline highlevel -kernel-pipeline highlevel | rocmlir-gen -ph -fut forward__part_0_wrapper -rand 1 -rand_type float --verifier clone -print-verify-results=always - | rocmlir-driver -host-pipeline 
mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s -// CHECK: RMS = {{.*}}e-07 -// CHECK: [1 1 0] +// CHECK: Number of elements: +// CHECK: RMS = {{0\.0e\+00|.*e-0[6-9]}} +// CHECK: [1 1 1] module { func.func private @forward__part_0(%arg0: tensor<1x64x56x56xf32> {mhal.read_access}, %arg1: tensor<64x64x3x3xf32> {mhal.read_access}, %arg2: tensor<1x64x1x1xf32> {mhal.read_access}, %arg3: tensor<64x1x1xf32> {mhal.read_access}, %arg4: tensor<1x64x1x1xf32> {mhal.read_access}, %arg5: tensor<1x64x1x1xf32> {mhal.read_access}, %arg6: tensor<1x56x56x64xf32> {mhal.read_access}) -> (tensor<1x64x56x56xf32> {mhal.write_access}) { diff --git a/mlir/test/xmir/pr-e2e/resnet18_blk1/resnet18_blk_part2.mlir b/mlir/test/xmir/pr-e2e/resnet18_blk1/resnet18_blk_part2.mlir index acc5b19a248c..4a7c142ba588 100644 --- a/mlir/test/xmir/pr-e2e/resnet18_blk1/resnet18_blk_part2.mlir +++ b/mlir/test/xmir/pr-e2e/resnet18_blk1/resnet18_blk_part2.mlir @@ -1,8 +1,9 @@ -// RUN: rocmlir-gen -fut forward__part_2 --arch %arch --clone-harness %s | rocmlir-driver -host-pipeline highlevel -kernel-pipeline highlevel | rocmlir-gen -ph -fut forward__part_2_wrapper -rand 1 -rand_type float --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: rocmlir-gen -fut forward__part_2 --arch %arch --clone-harness %s | rocmlir-driver -host-pipeline highlevel -kernel-pipeline highlevel | rocmlir-gen -ph -fut forward__part_2_wrapper -rand 1 -rand_type float --verifier clone -print-verify-results=always - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s -// CHECK: RMS = {{.*}}e-08 -// CHECK: [1 1 0] +// CHECK: Number of elements: +// CHECK: RMS = {{0\.0e\+00|.*e-0[6-9]}} +// CHECK: [1 1 1] module { func.func private @forward__part_2(%arg0: tensor<1x3x224x224xf32> {mhal.read_access}, %arg1: tensor<64x3x7x7xf32> {mhal.read_access}, %arg2: tensor<1x64x1x1xf32> {mhal.read_access}, %arg3: tensor<64x1x1xf32> {mhal.read_access}, %arg4: tensor<1x64x1x1xf32> {mhal.read_access}, %arg5: tensor<1x64x1x1xf32> {mhal.read_access}) -> (tensor<1x112x112x64xf32> {mhal.write_access}) { diff --git a/mlir/test/xmir/pr-e2e/resnet18_blk3/resnet18_blk3_part1.mlir b/mlir/test/xmir/pr-e2e/resnet18_blk3/resnet18_blk3_part1.mlir index 2c90bf41b197..f5042ff03040 100644 --- 
diff --git a/mlir/test/xmir/pr-e2e/resnet18_blk3/resnet18_blk3_part1.mlir b/mlir/test/xmir/pr-e2e/resnet18_blk3/resnet18_blk3_part1.mlir
index 2c90bf41b197..f5042ff03040 100644
--- a/mlir/test/xmir/pr-e2e/resnet18_blk3/resnet18_blk3_part1.mlir
+++ b/mlir/test/xmir/pr-e2e/resnet18_blk3/resnet18_blk3_part1.mlir
@@ -1,8 +1,9 @@
-// RUN: rocmlir-gen -fut forward__part_1 --arch %arch --clone-harness %s | rocmlir-driver -host-pipeline highlevel -kernel-pipeline highlevel | rocmlir-gen -ph -fut forward__part_1_wrapper -rand 1 -rand_type float --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
+// RUN: rocmlir-gen -fut forward__part_1 --arch %arch --clone-harness %s | rocmlir-driver -host-pipeline highlevel -kernel-pipeline highlevel | rocmlir-gen -ph -fut forward__part_1_wrapper -rand 1 -rand_type float --verifier clone -print-verify-results=always - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
 
-// CHECK: RMS = {{.*}}e-07
-// CHECK: [1 1 0]
+// CHECK: Number of elements:
+// CHECK: RMS = {{0\.0e\+00|.*e-0[6-9]}}
+// CHECK: [1 1 1]
 module {
   func.func private @forward__part_1(%arg0: tensor<1x128x28x28xf32> {mhal.read_access}, %arg1: tensor<128x128x3x3xf32> {mhal.read_access}, %arg2: tensor<1x128x1x1xf32> {mhal.read_access}, %arg3: tensor<1x128x1x1xf32> {mhal.read_access}, %arg4: tensor<1x128x1x1xf32> {mhal.read_access}, %arg5: tensor<1x128x1x1xf32> {mhal.read_access}) -> (tensor<1x128x28x28xf32> {mhal.write_access}) {
diff --git a/mlir/test/xmir/pr-e2e/resnet18_blk3/resnet18_blk3_part2.mlir b/mlir/test/xmir/pr-e2e/resnet18_blk3/resnet18_blk3_part2.mlir
index e5a6b63dd02e..709b8a32e172 100644
--- a/mlir/test/xmir/pr-e2e/resnet18_blk3/resnet18_blk3_part2.mlir
+++ b/mlir/test/xmir/pr-e2e/resnet18_blk3/resnet18_blk3_part2.mlir
@@ -1,8 +1,9 @@
-// RUN: rocmlir-gen -fut forward__part_2 --arch %arch --clone-harness %s | rocmlir-driver -host-pipeline highlevel -kernel-pipeline highlevel | rocmlir-gen -ph -fut forward__part_2_wrapper -rand 1 -rand_type float --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
+// RUN: rocmlir-gen -fut forward__part_2 --arch %arch --clone-harness %s | rocmlir-driver -host-pipeline highlevel -kernel-pipeline highlevel | rocmlir-gen -ph -fut forward__part_2_wrapper -rand 1 -rand_type float --verifier clone -print-verify-results=always - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
 
-// CHECK: RMS = {{.*}}e-07
-// CHECK: [1 1 0]
+// CHECK: Number of elements:
+// CHECK: RMS = {{0\.0e\+00|.*e-0[6-9]}}
+// CHECK: [1 1 1]
 module {
   func.func private @forward__part_2(%arg0: tensor<1x64x56x56xf32> {mhal.read_access}, %arg1: tensor<128x64x3x3xf32> {mhal.read_access}, %arg2: tensor<1x128x1x1xf32> {mhal.read_access}, %arg3: tensor<1x128x1x1xf32> {mhal.read_access}, %arg4: tensor<1x128x1x1xf32> {mhal.read_access}, %arg5: tensor<1x128x1x1xf32> {mhal.read_access}) -> (tensor<1x128x28x28xf32> {mhal.write_access}) {
diff --git a/mlir/test/xmir/xmir-runner.fail.mlir b/mlir/test/xmir/xmir-runner.fail.mlir
index 06cbadd2e636..15a8b8133d0e 100644
--- a/mlir/test/xmir/xmir-runner.fail.mlir
+++ b/mlir/test/xmir/xmir-runner.fail.mlir
@@ -35,8 +35,7 @@ module {
     return
   }
   func.func @resnet50(%arg0: memref<1x32x32x64xf32>, %arg1: memref<64x3x3x64xf32>, %arg2: memref<64x3x3x64xf32>, %arg3: memref<1x32x32x64xf32>) {
-    %token = mhal.launch @resnet50__part_0 (%arg0, %arg1, %arg3) : (memref<1x32x32x64xf32>, memref<64x3x3x64xf32>, memref<1x32x32x64xf32>)
-    mhal.await %token : !mhal.token
+    func.call @resnet50__part_0(%arg0, %arg1, %arg3) : (memref<1x32x32x64xf32>, memref<64x3x3x64xf32>, memref<1x32x32x64xf32>) -> ()
     return
   }
   func.func @main() {
diff --git a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
index b866916e7cd1..abf36a3e031e 100644
--- a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
+++ b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
@@ -4770,90 +4770,62 @@ static func::FuncOp createVerifierFunc(ModuleOp module, const KernelIF &kernel,
   return func;
 }
 
-// If the fut expects certain args (mostly output buffers),
-// this will populate the linalg.fill calls to do those based
-// on the presense of mhal::PrefillAttr. This is to mimic the
-// requirement on the kernel launcher to do the same for the
-// expected funtionality.
+// If the fut expects certain args (mostly output buffers), insert linalg.fill
+// calls before each func.call (the clone-harness wrapper invokes the kernel
+// via func.call) based on the presence of rock.prefill / mhal.write_access on
+// the callee. This mimics the kernel launcher's prefill responsibility.
 static void insertPrefills(func::FuncOp fut) {
   SmallVector<ModuleOp> innerModules;
   fut->getParentOfType<ModuleOp>().walk(
       [&](ModuleOp module) { innerModules.push_back(module); });
   innerModules.push_back(fut->getParentOfType<ModuleOp>());
-  fut.walk([&](mhal::LaunchOp launchOp) {
-    Location loc = launchOp->getLoc();
+
+  fut.walk([&](func::CallOp callOp) {
+    Location loc = callOp.getLoc();
     DenseMap<size_t, Attribute> argInitValues;
-    StringRef callee = launchOp.getCallee();
-    OpBuilder builder(launchOp);
+    OpBuilder builder(callOp);
     for (ModuleOp module : innerModules) {
-      if (func::FuncOp calleeFunc = module.lookupSymbol<func::FuncOp>(callee)) {
-        size_t argCount = calleeFunc.getArguments().size();
-        for (size_t i = 0; i < argCount; i++) {
-          if (Attribute initAttr =
-                  calleeFunc.getArgAttr(i, rock::PrefillAttr::getMnemonic())) {
-            argInitValues[i] = initAttr;
-          } else if (!argInitValues.contains(i) &&
-                     calleeFunc.getArgAttr(i, "mhal.write_access")) {
-            // initialize to 100 by default
-            // This ensures failure if the output tensor requires prefill,
-            // helping to detect uninitialized output in GPU vs CPU execution.
-            auto type = calleeFunc.getArgumentTypes()[i];
-            auto elementType = cast<ShapedType>(type).getElementType();
-            Attribute init;
-            if (llvm::isa<FloatType>(elementType)) {
-              init = builder.getFloatAttr(elementType, 100.0);
-            } else {
-              assert(llvm::isa<IntegerType>(elementType) &&
-                     "expecting `int` element type");
-              init = builder.getIntegerAttr(elementType, 100);
-            }
-            argInitValues[i] = init;
+      func::FuncOp calleeFunc =
+          module.lookupSymbol<func::FuncOp>(callOp.getCallee());
+      if (!calleeFunc)
+        continue;
+      size_t argCount = calleeFunc.getArguments().size();
+      for (size_t i = 0; i < argCount; i++) {
+        if (Attribute initAttr =
+                calleeFunc.getArgAttr(i, rock::PrefillAttr::getMnemonic())) {
+          argInitValues[i] = initAttr;
+        } else if (!argInitValues.contains(i) &&
+                   calleeFunc.getArgAttr(i, "mhal.write_access")) {
+          // Default-initialize write-access outputs to 100 so that any
+          // position the kernel fails to write differs from CPU's prefilled
+          // reference, surfacing uninitialized-output bugs.
+          auto type = calleeFunc.getArgumentTypes()[i];
+          auto elementType = cast<ShapedType>(type).getElementType();
+          Attribute init;
+          if (llvm::isa<FloatType>(elementType)) {
+            init = builder.getFloatAttr(elementType, 100.0);
+          } else {
+            assert(llvm::isa<IntegerType>(elementType) &&
+                   "expecting `int` element type");
+            init = builder.getIntegerAttr(elementType, 100);
          }
+          argInitValues[i] = init;
        }
      }
    }
-    {
-      OpBuilder::InsertionGuard guard(builder);
-      for (auto argIdxAndValueAttr : argInitValues) {
-        int argIdx = argIdxAndValueAttr.first;
-        auto valueAttr = argIdxAndValueAttr.second;
-        auto fillValue =
-            arith::ConstantOp::create(builder, loc, cast<TypedAttr>(valueAttr));
-        Value originalArg = launchOp.getArgOperands()[argIdx];
-        linalg::FillOp::create(builder, loc, ValueRange{fillValue},
-                               ValueRange{originalArg});
-      }
+    OpBuilder::InsertionGuard guard(builder);
+    for (auto argIdxAndValueAttr : argInitValues) {
+      int argIdx = argIdxAndValueAttr.first;
+      auto valueAttr = argIdxAndValueAttr.second;
+      auto fillValue =
+          arith::ConstantOp::create(builder, loc, cast<TypedAttr>(valueAttr));
+      Value originalArg = callOp.getOperands()[argIdx];
+      linalg::FillOp::create(builder, loc, ValueRange{fillValue},
+                             ValueRange{originalArg});
    }
  });
}
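To make this concrete, here is a hand-written sketch of the harness IR after insertPrefills runs, assuming a single mhal.write_access output with no rock.prefill attribute (names and shapes are invented). The fill lands immediately before the func.call because the builder is constructed at the call:

    func.func private @kernel(memref<16xf32> {mhal.read_access},
                              memref<16xf32> {mhal.write_access})

    func.func @kernel_wrapper(%in: memref<16xf32>, %out: memref<16xf32>) {
      // Inserted by insertPrefills: default-fill the write-access output.
      %cst = arith.constant 1.000000e+02 : f32
      linalg.fill ins(%cst : f32) outs(%out : memref<16xf32>)
      func.call @kernel(%in, %out) : (memref<16xf32>, memref<16xf32>) -> ()
      return
    }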
-// Convert the mhal.launch/mhal.await pattern back to func.call.
-static void undoAsyncLaunchPass(Operation *cloneFunc) {
-  SymbolTableCollection symbolTable;
-  auto walker = [&](Operation *op) {
-    OpBuilder builder(op);
-    if (auto launch = dyn_cast<mhal::LaunchOp>(op)) {
-      SymbolRefAttr calleeAttr = launch->getAttrOfType<SymbolRefAttr>("callee");
-      CallOpInterface callInt = dyn_cast<CallOpInterface>(op);
-      assert(callInt);
-      auto operands = callInt.getArgOperands();
-      auto call = func::CallOp::create(builder, op->getLoc(), calleeAttr,
-                                       TypeRange{}, operands);
-      call->moveBefore(op);
-      op->dropAllUses();
-      op->erase();
-      return WalkResult::interrupt();
-    }
-    if (auto launch = dyn_cast<mhal::AwaitOp>(op)) {
-      op->erase();
-      return WalkResult::interrupt();
-    }
-    return WalkResult::advance();
-  };
-  while (cloneFunc->walk(walker).wasInterrupted()) {
-  }
-}
-
 static bool isGpuValidationSupported(const GenParams &genParams) {
   // GPU validation is only supported for conv and gemm kernels
   return genParams.operation.has_value() &&
@@ -5008,13 +4980,14 @@ static void insertValidationCalls(const GenParams &genParams, OpBuilder &b,
       exit(1);
     }
   } else { // clone
-    // Clone the kernel-calling function. xmir-runner will call the appropriate
-    // binary kernel from the mhal.launch ops; here, we'll replace those with
-    // func.call which will get the MLIR kernel. No redirection of callees
-    // needed.
-    auto *cloneFunc = func->clone();
+    // Run prefills before cloning so both the GPU path (kernel binary) and
+    // the CPU validation path (*_cloned) share identical initial output
+    // contents. Previously prefills only landed on the GPU path, so
+    // uninitialized output positions diverged between CPU and GPU and the
+    // verifier had to tolerate <100% match on kernels that don't fully write
+    // their outputs (e.g. non-contiguous strides).
     insertPrefills(static_cast<func::FuncOp>(func));
-    undoAsyncLaunchPass(cloneFunc);
+    auto *cloneFunc = func->clone();
     SymbolOpInterface cloneFuncOp = dyn_cast<SymbolOpInterface>(cloneFunc);
     SmallString<128> nameBuffer(cloneFuncOp.getName());
     nameBuffer += "_cloned";
@@ -5701,11 +5674,9 @@ static void populateCloneHarnessLogic(ModuleOp module) {
                                      originalFunc.getFunctionType());
   Block *block = wrapperFunc.addEntryBlock();
   b.setInsertionPointToStart(block);
-  auto launchOp = mhal::LaunchOp::create(b, loc, originalFunc, ValueRange{},
-                                         block->getArguments());
-  auto results = launchOp->getResults();
-  mhal::AwaitOp::create(b, loc, results.front());
-  func::ReturnOp::create(b, loc, ValueRange{results.drop_front()});
+  auto callOp =
+      func::CallOp::create(b, loc, originalFunc, block->getArguments());
+  func::ReturnOp::create(b, loc, callOp.getResults());
   module.push_back(wrapperFunc);
 
   auto xmoduleOp = ModuleOp::create(loc, "__xmodule_");
diff --git a/mlir/tools/rocmlir-lib/rocmlir-lib.cpp b/mlir/tools/rocmlir-lib/rocmlir-lib.cpp
index 63d14c0471ef..8d84132e9c84 100644
--- a/mlir/tools/rocmlir-lib/rocmlir-lib.cpp
+++ b/mlir/tools/rocmlir-lib/rocmlir-lib.cpp
@@ -1,4 +1,5 @@
 #include "Miir.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MHAL/IR/MHAL.h"
 #include "mlir/Dialect/Rock/Generator/ConvGenerator.h"
 #include "mlir/Dialect/Rock/Pipelines/Pipelines.h"
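Tying the harness pieces together, the wrapper emitted by populateCloneHarnessLogic above is now an ordinary synchronous call-and-return instead of an mhal.launch producing a token awaited by mhal.await. A sketch of its shape, with an invented name and a single tensor result (the real wrapper mirrors the original function's full signature):

    func.func private @forward(tensor<16xf32>) -> tensor<16xf32>

    // Wrapper shape after this patch: a plain call, results returned directly.
    func.func @forward_wrapper(%arg0: tensor<16xf32>) -> tensor<16xf32> {
      %0 = func.call @forward(%arg0) : (tensor<16xf32>) -> tensor<16xf32>
      return %0 : tensor<16xf32>
    }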