Skip to content

Commit 08f1b27

Browse files
authored
Add TPtrDialect (#253)
This PR introduces TPTrDialect. This is a temporary dialect that includes some of the ops that operate on the `!ptr.ptr` type in the proposal here: https://discourse.llvm.org/t/rfc-ptr-dialect-modularizing-ptr-ops-in-the-llvm-dialect/75142. We will leverage this dialect to lower the remaining pointer operations that cannot be handled by `triton-to-structured` and `triton-to-unstructured` passes. Once the proposal is fully implemented, we will migrate over to using the official ops instead of this temporary dialect. The code is borrowed from the author's draft implementation here with some simplifications: + https://github.com/fabianmcg/mlir-address/blob/main/include/mlir/Dialect/Address/IR/AddressOps.td + https://github.com/llvm/llvm-project/pull/73057/files
1 parent 18e9ae7 commit 08f1b27

11 files changed

Lines changed: 345 additions & 0 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(TritonTilingExt)
22
add_subdirectory(TritonStructured)
3+
add_subdirectory(TPtr)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(IR)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
set(LLVM_TARGET_DEFINITIONS TPtrDialect.td)
2+
mlir_tablegen(TPtrDialect.h.inc -gen-dialect-decls -dialect=tptr)
3+
mlir_tablegen(TPtrDialect.cpp.inc -gen-dialect-defs -dialect=tptr)
4+
mlir_tablegen(TPtrOps.h.inc -gen-op-decls)
5+
mlir_tablegen(TPtrOps.cpp.inc -gen-op-defs)
6+
7+
set(LLVM_TARGET_DEFINITIONS TPtrDialect.td)
8+
mlir_tablegen(TPtrTypes.h.inc -gen-typedef-decls -typedefs-dialect=tptr)
9+
mlir_tablegen(TPtrTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=tptr)
10+
11+
add_public_tablegen_target(TPtrTableGen)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef MLIR_DIALECT_TPTR_IR_TPTR_DIALECT_H_
2+
#define MLIR_DIALECT_TPTR_IR_TPTR_DIALECT_H_
3+
4+
#include "mlir/Interfaces/SideEffectInterfaces.h" // Required for IR/TPtrOps.h.inc
5+
#include "mlir/Bytecode/BytecodeOpInterface.h"
6+
7+
#include "mlir/Dialect/Ptr/IR/PtrDialect.h" // Required for IR/TPtrOps.h.inc
8+
#include "mlir/Dialect/Ptr/IR/PtrTypes.h" // Required for IR/TPtrOps.h.inc
9+
10+
//===----------------------------------------------------------------------===//
11+
// Temporary Pointer Dialect Operations
12+
//===----------------------------------------------------------------------===//
13+
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h.inc"
14+
15+
// Include the auto-generated header file containing the declarations of the
16+
// Temporary Pointer Dialect operations.
17+
#define GET_OP_CLASSES
18+
#include "triton-shared/Dialect/TPtr/IR/TPtrOps.h.inc"
19+
20+
#define GET_TYPEDEF_CLASSES
21+
#include "triton-shared/Dialect/TPtr/IR/TPtrTypes.h.inc"
22+
23+
#endif
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
#ifndef TPTR_DIALECT
2+
#define TPTR_DIALECT
3+
4+
include "mlir/IR/OpBase.td"
5+
include "mlir/Interfaces/SideEffectInterfaces.td"
6+
include "mlir/Dialect/Ptr/IR/PtrDialect.td"
7+
include "mlir/IR/AttrTypeBase.td"
8+
include "mlir/IR/BuiltinTypeInterfaces.td"
9+
10+
def TPtr_Dialect : Dialect {
11+
let name = "tptr";
12+
13+
let cppNamespace = "::mlir::tptr";
14+
15+
let summary = "Temporary Pointer Dialect";
16+
17+
let description = [{
18+
Typed Pointer Dialect.
19+
}];
20+
21+
let extraClassDeclaration = [{
22+
void registerTypes();
23+
}];
24+
25+
let dependentDialects = [
26+
"mlir::ptr::PtrDialect"
27+
];
28+
29+
let usePropertiesForAttributes = 1;
30+
}
31+
32+
class TPtrTypeDef<string name, string _mnemonic, list<Trait> traits = []>
33+
: TypeDef<TPtr_Dialect, name, traits> {
34+
// Used by printer/parser
35+
let mnemonic = _mnemonic;
36+
}
37+
38+
//
39+
// Op Base
40+
//
41+
class TPTR_Op<string mnemonic, list<Trait> traits = []> :
42+
Op<TPtr_Dialect, mnemonic, traits> {
43+
}
44+
45+
def TPTR_IntToPtrOp : TPTR_Op<"inttoptr", [
46+
Pure
47+
]> {
48+
let summary = "Integer to a pointer operation";
49+
let description = [{
50+
The `inttoptr` operation casts an int or index value to a pointer.
51+
52+
Example:
53+
```mlir
54+
%ptr = ptr.inttoptr %int : i32 to !ptr.ptr<1 : i32>
55+
```
56+
}];
57+
let arguments = (ins AnySignlessIntegerOrIndex:$arg);
58+
let results = (outs Ptr_PtrType:$res);
59+
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
60+
}
61+
62+
def TPTR_PtrToIntOp : TPTR_Op<"ptrtoint", [
63+
Pure
64+
]> {
65+
let summary = "Pointer to an integer operation";
66+
let description = [{
67+
The `ptrtoint` operation casts a pointer value to an int or index.
68+
69+
Example:
70+
```mlir
71+
%int = ptr.ptrtoint %ptr : !ptr.ptr<1 : i32> to i32
72+
```
73+
}];
74+
let arguments = (ins Ptr_PtrType:$arg);
75+
let results = (outs AnySignlessIntegerOrIndex:$res);
76+
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
77+
}
78+
79+
def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> {
80+
let summary = "Creates a type offset constant.";
81+
let description = [{
82+
The `addr.type_offset` operation produces an int or index-typed SSA value
83+
equal to a target-specific constant representing the offset of a single
84+
element of the given type. The default return type is `index`.
85+
Example:
86+
87+
```mlir
88+
%0 = addr.type_offset f32
89+
%1 = addr.type_offset memref<12 x f64> : i32
90+
```
91+
}];
92+
93+
let arguments = (ins TypeAttr:$baseType);
94+
let results = (outs AnySignlessIntegerOrIndex:$result);
95+
let builders = [
96+
OpBuilder<(ins "TypeAttr":$baseType, CArg<"Type", "nullptr">:$resultTy)>
97+
];
98+
let assemblyFormat = [{
99+
attr-dict $baseType custom<IntType>(type($result))
100+
}];
101+
let hasFolder = 1;
102+
}
103+
104+
def TPTR_FromMemrefOp : TPTR_Op<"from_memref", [Pure]> {
105+
let arguments = (ins AnyMemRef:$input);
106+
let results = (outs Ptr_PtrType:$result);
107+
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
108+
}
109+
110+
def TPTR_ToMemrefOp : TPTR_Op<"to_memref", [
111+
Pure ]> {
112+
let arguments = (ins Ptr_PtrType:$arg);
113+
let results = (outs AnyStaticShapeMemRef:$res);
114+
let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)";
115+
}
116+
117+
def TPTR_PtrAddOp : TPTR_Op<"ptradd", [Pure, AllTypesMatch<["base", "result"]>]> {
118+
let summary = "Pointer-index add operation";
119+
let description = [{
120+
The `ptradd` operation adds an `address` and an integer or index to
121+
produce a new address.
122+
123+
Example:
124+
```mlir
125+
%addr = ptr.ptradd %addr : !ptr.ptr<3 : i32>, %c10 : i32
126+
```
127+
}];
128+
129+
let arguments = (ins Ptr_PtrType:$base, AnySignlessIntegerOrIndex:$offset);
130+
let results = (outs Ptr_PtrType:$result);
131+
let assemblyFormat = "$base $offset attr-dict `:` type($base) `,` type($offset) `to` type($result)";
132+
}
133+
134+
def TPTR_LoadOp : TPTR_Op<"load", [
135+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
136+
]> {
137+
let summary = "Load operation";
138+
let description = [{
139+
The `load` operation is used to read from memory. A load may be marked as
140+
atomic, volatile, and/or nontemporal, and takes a number of optional
141+
attributes that specify aliasing information.
142+
143+
An atomic load only supports a limited set of pointer, integer, and
144+
floating point types, and requires an explicit alignment.
145+
146+
Examples:
147+
```mlir
148+
// A volatile load of a float variable.
149+
%0 = ptr.load volatile %ptr : !ptr.ptr -> f32
150+
151+
// A nontemporal load of a float variable.
152+
%0 = ptr.load %ptr {nontemporal} : !ptr.ptr -> f32
153+
154+
// An atomic load of an integer variable.
155+
%0 = ptr.load %ptr atomic monotonic {alignment = 8 : i64}
156+
: !ptr.ptr -> i64
157+
```
158+
}];
159+
let arguments = (ins AnyType:$addr);
160+
let results = (outs AnyType:$res);
161+
let assemblyFormat = [{
162+
$addr
163+
attr-dict `:` qualified(type($addr)) `->` type($res)
164+
}];
165+
}
166+
167+
def TTPTR_StoreOp : TPTR_Op<"store", [
168+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
169+
]> {
170+
let summary = "Store operation";
171+
let description = [{
172+
The `store` operation is used to write to memory. A store may be marked as
173+
atomic, volatile, and/or nontemporal, and takes a number of optional
174+
attributes that specify aliasing information.
175+
176+
An atomic store only supports a limited set of pointer, integer, and
177+
floating point types, and requires an explicit alignment.
178+
179+
Examples:
180+
```mlir
181+
// A volatile store of a float variable.
182+
ptr.store volatile %val, %ptr : f32, !ptr.ptr
183+
184+
// A nontemporal store of a float variable.
185+
ptr.store %val, %ptr {nontemporal} : f32, !ptr.ptr
186+
187+
// An atomic store of an integer variable.
188+
ptr.store %val, %ptr atomic monotonic {alignment = 8 : i64}
189+
: i64, !ptr.ptr
190+
```
191+
}];
192+
let arguments = (ins AnyType:$value,
193+
AnyType:$addr);
194+
let assemblyFormat = [{
195+
$value `,` $addr
196+
attr-dict `:` type($value) `,` qualified(type($addr))
197+
}];
198+
}
199+
200+
#endif // TPTR_DIALECT

lib/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(TritonTilingExt)
22
add_subdirectory(TritonStructured)
3+
add_subdirectory(TPtr)

lib/Dialect/TPtr/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(IR)

lib/Dialect/TPtr/IR/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
add_triton_library(TPtrIR
2+
TPtrOps.cpp
3+
TPtrDialect.cpp
4+
5+
DEPENDS
6+
TPtrTableGen
7+
8+
LINK_LIBS PUBLIC
9+
TritonIR
10+
MLIRIR
11+
MLIRPtrDialect
12+
)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#include "mlir/IR/Builders.h"
2+
3+
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
4+
5+
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
6+
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
7+
8+
#define GET_TYPEDEF_CLASSES
9+
#include "triton-shared/Dialect/TPtr/IR/TPtrTypes.cpp.inc"
10+
11+
using namespace mlir;
12+
13+
namespace {
14+
ParseResult parseIntType(OpAsmParser &parser, Type &ty) {
15+
if (succeeded(parser.parseOptionalColon()) && parser.parseType(ty))
16+
return parser.emitError(parser.getNameLoc(), "expected a type");
17+
if (!ty)
18+
ty = parser.getBuilder().getIndexType();
19+
return success();
20+
}
21+
void printIntType(OpAsmPrinter &p, Operation *op, Type ty) {
22+
if (!ty.isIndex())
23+
p << " : " << ty;
24+
}
25+
} // namespace
26+
27+
//===----------------------------------------------------------------------===//
28+
// Dialect
29+
//===----------------------------------------------------------------------===//
30+
void mlir::tptr::TPtrDialect::registerTypes() {
31+
addTypes<
32+
#define GET_TYPEDEF_LIST
33+
#include "triton-shared/Dialect/TPtr/IR/TPtrTypes.cpp.inc"
34+
>();
35+
}
36+
37+
/// Dialect creation, the instance will be owned by the context. This is the
38+
/// point of registration of custom types and operations for the dialect.
39+
void mlir::tptr::TPtrDialect::initialize() {
40+
registerTypes();
41+
addOperations<
42+
#define GET_OP_LIST
43+
#include "triton-shared/Dialect/TPtr/IR/TPtrOps.cpp.inc"
44+
>();
45+
}
46+
47+
//===----------------------------------------------------------------------===//
48+
// TableGen'd op method definitions
49+
//===----------------------------------------------------------------------===//
50+
51+
#define GET_OP_CLASSES
52+
#include "triton-shared/Dialect/TPtr/IR/TPtrOps.cpp.inc"
53+
54+
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.cpp.inc"

lib/Dialect/TPtr/IR/TPtrOps.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "mlir/Interfaces/SideEffectInterfaces.h" // Required for IR/TPtrOps.h.inc
2+
#include "mlir/Bytecode/BytecodeOpInterface.h"
3+
4+
#include "mlir/IR/OpImplementation.h"
5+
#include "mlir/IR/Builders.h"
6+
#include "mlir/IR/BuiltinAttributes.h"
7+
#include "mlir/IR/BuiltinTypes.h"
8+
#include "mlir/IR/MLIRContext.h"
9+
#include "mlir/IR/OperationSupport.h"
10+
#include "mlir/IR/OpDefinition.h"
11+
#include "mlir/IR/Dialect.h"
12+
13+
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
14+
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
15+
16+
#define GET_OP_CLASSES
17+
#include "triton-shared/Dialect/TPtr/IR/TPtrOps.h.inc"
18+
19+
using namespace mlir;
20+
using namespace mlir::tptr;
21+
22+
void LoadOp::getEffects(
23+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
24+
&effects) {
25+
effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable(),
26+
SideEffects::DefaultResource::get());
27+
}
28+
29+
void StoreOp::getEffects(
30+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
31+
&effects) {
32+
effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable(),
33+
SideEffects::DefaultResource::get());
34+
}
35+
36+
OpFoldResult TypeOffsetOp::fold(FoldAdaptor adaptor) {
37+
return adaptor.getBaseTypeAttr();
38+
}

0 commit comments

Comments
 (0)