Skip to content

Commit eea2876

Browse files
authored
Skip lowering triton bitcast on tensor of pointers (#250)
We currently only skip lowering tt.bitcast if the input is of triton pointer type. tt.bitcast can also take tensor of pointers, so we need to skip lowering in that case as well.
1 parent 08f1b27 commit eea2876

7 files changed

Lines changed: 40 additions & 3 deletions

File tree

include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "triton-shared/Analysis/OpFoldResultUtils.h"
1313
#include "triton-shared/Analysis/PtrAnalysis.h"
1414
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
15+
#include "triton-shared/Utils/Utils.h"
1516

1617
#include "triton/Dialect/Triton/IR/Dialect.h"
1718

@@ -853,7 +854,7 @@ struct BitcastConverter : public OpConversionPattern<triton::BitcastOp> {
853854
matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor,
854855
ConversionPatternRewriter &rewriter) const override {
855856
// arith::bitcast does not support casting pointers
856-
if (isa<triton::PointerType>(op.getSrc().getType())) {
857+
if (triton::isPtrTypeLike(op.getType())) {
857858
return failure();
858859
}
859860

@@ -1577,7 +1578,7 @@ class ArgMinMaxBaseConverter : public OpConversionPattern<triton::ReduceOp> {
15771578
} else {
15781579
return failure();
15791580
}
1580-
1581+
15811582
auto loc = op.getLoc();
15821583

15831584
auto elemTypes = op.getElementTypes();
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#ifndef TRITON_SHARED_UTILITY_H
2+
#define TRITON_SHARED_UTILITY_H
3+
4+
#include "triton/Dialect/Triton/IR/Dialect.h"
5+
6+
namespace mlir {
7+
namespace triton {
8+
// Return true if the input type is a triton pointer or a tensor of triton pointers
9+
bool isPtrTypeLike(Type t);
10+
} // namespace triton
11+
12+
} // namespace mlir
13+
14+
#endif // TRITON_SHARED_UTILITY_H

lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ add_subdirectory(Analysis)
22
add_subdirectory(AnalysisStructured)
33
add_subdirectory(Conversion)
44
add_subdirectory(Dialect)
5+
add_subdirectory(Utils)

lib/Conversion/TritonArithToLinalg/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ add_triton_library(TritonArithToLinalg
1919
TritonTransforms
2020
TritonTilingExtIR
2121
TritonStructuredIR
22+
TritonSharedUtils
2223
)

lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "triton-shared/Conversion/TritonArithToLinalg/TritonArithToLinalg.h"
1111
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
1212
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
13+
#include "triton-shared/Utils/Utils.h"
1314
#include "triton/Dialect/Triton/IR/Dialect.h"
1415

1516
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -123,7 +124,7 @@ class TritonArithToLinalgPass
123124
target.addLegalOp<triton::FuncOp, triton::ReturnOp>();
124125

125126
target.addDynamicallyLegalOp<triton::BitcastOp>([](triton::BitcastOp op) {
126-
return isa<triton::PointerType>(op.getSrc().getType());
127+
return triton::isPtrTypeLike(op.getType());
127128
});
128129

129130
target.addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect>(

lib/Utils/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
add_triton_library(TritonSharedUtils
2+
Utils.cpp
3+
4+
LINK_LIBS PUBLIC
5+
TritonIR
6+
)

lib/Utils/Utils.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include "triton/Dialect/Triton/IR/Dialect.h"
2+
3+
namespace mlir {
4+
namespace triton {
5+
bool isPtrTypeLike(Type t) {
6+
if (auto tensorType = dyn_cast<RankedTensorType>(t)) {
7+
return isa<triton::PointerType>(tensorType.getElementType());
8+
}
9+
return isa<triton::PointerType>(t);
10+
}
11+
} // namespace triton
12+
13+
} // namespace mlir

0 commit comments

Comments
 (0)