File tree Expand file tree Collapse file tree
Conversion/TritonArithToLinalg
Conversion/TritonArithToLinalg Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1212#include " triton-shared/Analysis/OpFoldResultUtils.h"
1313#include " triton-shared/Analysis/PtrAnalysis.h"
1414#include " triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
15+ #include " triton-shared/Utils/Utils.h"
1516
1617#include " triton/Dialect/Triton/IR/Dialect.h"
1718
@@ -853,7 +854,7 @@ struct BitcastConverter : public OpConversionPattern<triton::BitcastOp> {
853854 matchAndRewrite (triton::BitcastOp op, OpAdaptor adaptor,
854855 ConversionPatternRewriter &rewriter) const override {
855856 // arith::bitcast does not support casting pointers
856- if (isa< triton::PointerType> (op. getSrc () .getType ())) {
857+ if (triton::isPtrTypeLike (op.getType ())) {
857858 return failure ();
858859 }
859860
@@ -1577,7 +1578,7 @@ class ArgMinMaxBaseConverter : public OpConversionPattern<triton::ReduceOp> {
15771578 } else {
15781579 return failure ();
15791580 }
1580-
1581+
15811582 auto loc = op.getLoc ();
15821583
15831584 auto elemTypes = op.getElementTypes ();
Original file line number Diff line number Diff line change 1+ #ifndef TRITON_SHARED_UTILITY_H
2+ #define TRITON_SHARED_UTILITY_H
3+
4+ #include " triton/Dialect/Triton/IR/Dialect.h"
5+
6+ namespace mlir {
7+ namespace triton {
8+ // Return true if the input type is a triton pointer or a tensor of triton pointers
9+ bool isPtrTypeLike (Type t);
10+ } // namespace triton
11+
12+ } // namespace mlir
13+
14+ #endif // TRITON_SHARED_UTILITY_H
Original file line number Diff line number Diff line change @@ -2,3 +2,4 @@ add_subdirectory(Analysis)
22add_subdirectory (AnalysisStructured )
33add_subdirectory (Conversion )
44add_subdirectory (Dialect )
5+ add_subdirectory (Utils )
Original file line number Diff line number Diff line change @@ -19,4 +19,5 @@ add_triton_library(TritonArithToLinalg
1919 TritonTransforms
2020 TritonTilingExtIR
2121 TritonStructuredIR
22+ TritonSharedUtils
2223)
Original file line number Diff line number Diff line change 1010#include " triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h"
1111#include " triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
1212#include " triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
13+ #include " triton-shared/Utils/Utils.h"
1314#include " triton/Dialect/Triton/IR/Dialect.h"
1415
1516#include " mlir/Dialect/Affine/IR/AffineOps.h"
@@ -123,7 +124,7 @@ class TritonArithToLinalgPass
123124 target.addLegalOp <triton::FuncOp, triton::ReturnOp>();
124125
125126 target.addDynamicallyLegalOp <triton::BitcastOp>([](triton::BitcastOp op) {
126- return isa< triton::PointerType> (op. getSrc () .getType ());
127+ return triton::isPtrTypeLike (op.getType ());
127128 });
128129
129130 target.addDynamicallyLegalDialect <arith::ArithDialect, math::MathDialect>(
Original file line number Diff line number Diff line change 1+ add_triton_library (TritonSharedUtils
2+ Utils.cpp
3+
4+ LINK_LIBS PUBLIC
5+ TritonIR
6+ )
Original file line number Diff line number Diff line change 1+ #include " triton/Dialect/Triton/IR/Dialect.h"
2+
3+ namespace mlir {
4+ namespace triton {
5+ bool isPtrTypeLike (Type t) {
6+ if (auto tensorType = dyn_cast<RankedTensorType>(t)) {
7+ return isa<triton::PointerType>(tensorType.getElementType ());
8+ }
9+ return isa<triton::PointerType>(t);
10+ }
11+ } // namespace triton
12+
13+ } // namespace mlir
You can’t perform that action at this time.
0 commit comments