Skip to content

Commit 19fb6f3

Browse files
python3kgaeXiang Li
andauthored
Support for Gather/Scatter access with only one non-continuous dimension (#239)
A new operation, TTS_MakeGatherScatterTensorPtrOp, was added. It is similar to TTS_MakeTensorPtrOp, but includes gather_scatter_dim and gather_scatter_offset to store information about the non-continuous dimension. With TTS_MakeIndirectTensorPtrOp, the continuous memory accesses for the rest of the dimensions are grouped into one operation, which keeps the high-level semantics of the operation intact. In PtrAnalysis: If an operation works on a single dimension, unsupported operations will have PtrState start from the operation with the operation as the indirect offset. For other operations that fail the structured check, reset the dimension using the operation as the indirect offset if it works on a single dimension. When dealing with mulState and addState: For the indirect dimension, set the stride to 1 and use offset * stride as the new offset. If one state has a modulo and the other state is not structured, clear the modulo and use the operand as the offset directly. Also skipped loop case when the PtrState is not structured. --------- Co-authored-by: Xiang Li <xiagli@microsoft.com>
1 parent 7f38361 commit 19fb6f3

10 files changed

Lines changed: 694 additions & 59 deletions

File tree

include/triton-shared/Analysis/OpFoldResultUtils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, OpBuilder &b);
3535
SmallVector<Value> ofrsToIndexValues(ArrayRef<OpFoldResult> ofrs,
3636
const Location loc, OpBuilder &b);
3737

38+
// Expand index to given type.
39+
OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetOrfForTy,
40+
const Location loc, OpBuilder &b);
41+
3842
// Process addition of two OFRs. If both OFRs are Integer Attributes, result
3943
// is an Integer Attribute. Otherwise, insert the arith.addi instruction if
4044
// needed and use its result Value.
@@ -50,7 +54,7 @@ OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
5054
// Process multiplication of two OFRs. If both OFRs are Integer Attributes,
5155
// result is an Integer Attribtue. Otherwise, insert the arith.muli
5256
// instruction if needed and use its result Value.
53-
OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,
57+
OpFoldResult mulOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
5458
const Location loc, OpBuilder &b);
5559

5660
OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,

include/triton-shared/AnalysisStructured/PtrAnalysis.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,31 @@ struct PtrState {
5757

5858
bool dimHasModulo(uint32_t dim) const;
5959

60+
bool dimIsStructured(uint32_t dim) const;
61+
int32_t getNonStructuredDim() const;
62+
// When rank is 1, and the only dimension is not continuous.
63+
// There's no dimension is continuous.
64+
bool noStructuredDim() const;
65+
66+
bool isStructured() const;
67+
6068
bool isBlockPtr() const;
6169

6270
void dump() const;
6371

72+
// For unsupported op, save the op to the state.
73+
LogicalResult rebuildAsUnsupportedOp(Value op);
74+
75+
// When merge with other state which is not structured, set the nonContinuous dimension
76+
// offset as op.
77+
// Still need to make sure the op only contribute to nonContinuousDim.
78+
// Fail if the op already mix of different dims.
79+
// For case
80+
// add %remsi(on dim0), %mul(dim1)
81+
// the add will have both dim0 and dim1
82+
// to rebuild use the op, it has to use op[nonContinuousDim] which is not supported.
83+
LogicalResult rebuildAsGatherScatter(Value op, int nonContinuousDim);
84+
6485
// Process addition of two PtrStates.
6586
LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState,
6687
Operation *op, OpBuilder &builder);
@@ -71,6 +92,8 @@ struct PtrState {
7192

7293
tts::MakeTensorPtrOp createTTSMakeTensorPtrOp(OpBuilder &builder,
7394
Location loc);
95+
tts::MakeGatherScatterTensorPtrOp
96+
createTTSMakeGatherScatterTensorPtrOp(OpBuilder &builder, Location loc);
7497
};
7598

7699
class PtrAnalysis {

include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,82 @@ def TTS_MakeTensorPtrOp
120120
//let hasCanonicalizer = 1;
121121
}
122122

123+
def TTS_MakeGatherScatterTensorPtrOp
124+
: TTS_Op<"make_gather_scatter_tptr", [AttrSizedOperandSegments, Pure]> {
125+
// NOTE: Only support cases where the offset for each dimension is defined in a different operation.
126+
// Not support case where the offset is a tensor load from other ptr which for multiple dimension.
127+
//
128+
// offset_m = tl.arange(0, M)
129+
// offset_n = tl.arange(0, N)
130+
// offset_k = tl.arange(0, K)
131+
// ld_offsets = tl.load(a_ptr + offset_m[:,None]+offsets_n[None,:])
132+
// not_support = tl.load(b_ptr + ld_offsets)
133+
// not_support2 = tl.load(b_ptr + ld_offsets * (offset_m[:,None]+offsets_n[None,:]))
134+
// not_support3 = tl.load(b_ptr + (ld_offsets * (offset_m[:,None]+offsets_n[None,:]))[:, :, None] + offset_k[None,None,:])
135+
//
136+
// # Support cases where one dimension is structured while the other is not.
137+
// # For example, `offset_m[:, None] // K` is not structured, whereas `offset_n[None, :]` is structured in next line.
138+
// supported = tl.load(b_ptr + offset_m[:, None] // K + offset_n[None, :])
139+
140+
let summary = "create an pointer that points to a tensor in memory for gather/scatter";
141+
let description = [{
142+
The `tts.make_gather_scatter_tptr` operation is similar to `tts.make_tptr`.
143+
The key difference is that `make_gather_scatter_tptr` accesses the tensor non-continuously.
144+
Currently, only one dimension is allowed to be non-continuous.
145+
This dimension is saved in `gather_scatter_dim`, and the offset for that dimension is saved in `gather_scatter_offset`.
146+
Each contiguous load will load from this offset.
147+
Cases with more than one non-continuous dimension are not supported.
148+
}];
149+
150+
// base: Base pointer used to contruct the tensor of pointers or pointer to tensor.
151+
// gather_scatter_offset: The offset for gather/scatter.
152+
// gather_scatter_dim: The dimension for gather_scatter_offset.
153+
// sizes: Size of the data being loaded or stored.
154+
// strides: The strides of the parent tensor, which means how much to increase the pointer
155+
// by when moving by 1 element in a specific axis.
156+
// offsets: Offset of the block along each dimension from base.
157+
// result: If order is present, this op produces a pointer to a tensor; otherwise,
158+
// it produces a tensor of pointers.
159+
160+
let arguments = (ins TT_Ptr:$base,
161+
I32Tensor:$gather_scatter_offset,
162+
I32Attr:$gather_scatter_dim,
163+
DenseI64ArrayAttr:$sizes,
164+
Variadic<Index>:$strides,
165+
Variadic<Index>:$offsets,
166+
DenseI64ArrayAttr:$static_strides,
167+
DenseI64ArrayAttr:$static_offsets);
168+
169+
let results = (outs TT_PtrLike:$result);
170+
171+
let assemblyFormat = [{
172+
$base `to` `sizes` `` `:` $sizes
173+
`gather_scatter_dim` `` `:` $gather_scatter_dim
174+
`gather_scatter_offset` `` `:` $gather_scatter_offset
175+
`` `,` `strides` `` `:`
176+
custom<DynamicIndexList>($strides, $static_strides)
177+
`` `,` `offsets` `` `:`
178+
custom<DynamicIndexList>($offsets, $static_offsets)
179+
attr-dict `:` type($gather_scatter_offset) type($base) `to` type($result)
180+
}];
181+
182+
183+
let builders = [
184+
// Build with mixed static and dynamic entries.
185+
OpBuilder<(ins
186+
"Value":$base,
187+
"Value":$gather_scatter_offset,
188+
"int":$gather_scatter_dim,
189+
"ArrayRef<int64_t>":$sizes,
190+
"ArrayRef<OpFoldResult>":$strides,
191+
"ArrayRef<OpFoldResult>":$offsets)>,
192+
];
193+
194+
// TODO
195+
//let hasVerifier = 1;
196+
//let hasCanonicalizer = 1;
197+
}
198+
123199
def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSegments, Pure]> {
124200
let summary = "Placeholder for the structured pointer states computed during PtrAnalysis.";
125201
let description = "Used to pass the offsets and strides to scf.for op to simplify IR rewrites.";

lib/Analysis/OpFoldResultUtils.cpp

Lines changed: 97 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/IR/BuiltinTypes.h"
1313
#include "mlir/IR/OpDefinition.h"
1414
#include "mlir/Transforms/DialectConversion.h"
15+
#include "triton/Dialect/Triton/IR/Dialect.h"
1516

1617
namespace mlir {
1718

@@ -74,6 +75,69 @@ SmallVector<Value> ofrsToIndexValues(ArrayRef<OpFoldResult> ofrs,
7475
}));
7576
}
7677

78+
Value indexTypeCast(Value v, Type targetTy, const Location loc, OpBuilder &b) {
79+
Type ty = v.getType();
80+
if (isa<IndexType>(targetTy) || isa<IndexType>(ty)) {
81+
assert((isa<IntegerType>(targetTy) || isa<IntegerType>(ty)) &&
82+
"Only cast between index type and integer type");
83+
return b.create<arith::IndexCastOp>(loc, targetTy, v).getResult();
84+
} else {
85+
auto targetIntTy = cast<IntegerType>(targetTy);
86+
auto intTy = cast<IntegerType>(ty);
87+
if (targetIntTy.getWidth() > intTy.getWidth())
88+
return b.create<arith::ExtSIOp>(loc, targetTy, v).getResult();
89+
else
90+
return b.create<arith::TruncIOp>(loc, targetTy, v).getResult();
91+
}
92+
}
93+
94+
OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy,
95+
const Location loc, OpBuilder &b) {
96+
if (getIntAttr(targetForTy))
97+
return ofr;
98+
Value targetValueForTy = cast<Value>(targetForTy);
99+
Type targetTy = targetValueForTy.getType();
100+
auto targetShapedTy = dyn_cast<ShapedType>(targetTy);
101+
102+
Value v = dyn_cast<Value>(ofr);
103+
if (!v)
104+
v = b.create<arith::ConstantOp>(loc, cast<IntegerAttr>(cast<Attribute>(ofr)));
105+
106+
Type ty = v.getType();
107+
if (targetTy == ty)
108+
return ofr;
109+
110+
auto shapedTy = dyn_cast<ShapedType>(ty);
111+
if (targetShapedTy && !shapedTy) {
112+
Type targetEltTy = targetShapedTy.getElementType();
113+
// cast to target element type first.
114+
if (targetEltTy != ty)
115+
v = indexTypeCast(v, targetEltTy, loc, b);
116+
return b.create<triton::SplatOp>(loc, targetTy, v).getResult();
117+
} else if (targetShapedTy && shapedTy) {
118+
// TODO: support ShapedType to ShapedType.
119+
Type targetEltTy = targetShapedTy.getElementType();
120+
Type eltTy = shapedTy.getElementType();
121+
if (targetShapedTy.getShape() != shapedTy.getShape())
122+
llvm_unreachable("ShapedType to ShapedType must have same shape");
123+
if (isa<IndexType>(targetEltTy) || isa<IndexType>(eltTy)) {
124+
assert((isa<IntegerType>(targetEltTy) || isa<IntegerType>(eltTy)) &&
125+
"Only cast between index type and integer type");
126+
return b.create<arith::IndexCastOp>(loc, targetTy, v).getResult();
127+
} else {
128+
auto targetIntTy = cast<IntegerType>(targetEltTy);
129+
auto intTy = cast<IntegerType>(eltTy);
130+
if (targetIntTy.getWidth() > intTy.getWidth())
131+
return b.create<arith::ExtSIOp>(loc, targetTy, v).getResult();
132+
else
133+
return b.create<arith::TruncIOp>(loc, targetTy, v).getResult();
134+
}
135+
} else {
136+
assert(!shapedTy && "src type rank should be >= target type rank");
137+
return indexTypeCast(v, targetTy, loc, b);
138+
}
139+
}
140+
77141
OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
78142
const Location loc, OpBuilder &b) {
79143
auto lhsIntAttr = getIntAttr(lhs);
@@ -95,17 +159,13 @@ OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
95159
auto lhsOp =
96160
b.create<arith::ConstantOp>(loc, b.getIndexAttr(lhsIntAttr.value()));
97161
lhsValue = lhsOp.getResult();
98-
} else {
99-
assert(isa<IndexType>(lhsValue.getType()));
100162
}
101163

102164
auto rhsValue = dyn_cast<Value>(rhs);
103165
if (rhsIntAttr) {
104166
auto rhsOp =
105167
b.create<arith::ConstantOp>(loc, b.getIndexAttr(rhsIntAttr.value()));
106168
rhsValue = rhsOp.getResult();
107-
} else {
108-
assert(isa<IndexType>(lhsValue.getType()));
109169
}
110170

111171
return b.create<arith::AddIOp>(loc, lhsValue, rhsValue).getResult();
@@ -143,50 +203,57 @@ OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
143203
return sumOp.getResult();
144204
}
145205

146-
OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,
206+
OpFoldResult mulOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
147207
const Location loc, OpBuilder &b) {
148208
auto lhsIntAttr = getIntAttr(lhs);
209+
auto rhsIntAttr = getIntAttr(rhs);
149210

150-
auto rhsIsConst = false;
151-
// if rhs is not a const, use max value since min is used to represent
152-
// dynamic size or stride
153-
auto rhsConstValue = std::numeric_limits<int64_t>::max();
154-
auto rhsOp = rhs.getDefiningOp<arith::ConstantOp>();
155-
if (rhsOp) {
156-
rhsIsConst = true;
157-
rhsConstValue = cast<IntegerAttr>(rhsOp.getValue()).getInt();
211+
auto lhsValue = dyn_cast<Value>(lhs);
212+
if (lhsValue) {
213+
if (auto lhsOp = lhsValue.getDefiningOp<arith::ConstantOp>()) {
214+
lhsIntAttr = cast<IntegerAttr>(lhsOp.getValue()).getInt();
215+
}
216+
}
217+
auto rhsValue = dyn_cast<Value>(rhs);
218+
if (rhsValue) {
219+
if (auto rhsOp = rhsValue.getDefiningOp<arith::ConstantOp>()) {
220+
rhsIntAttr = cast<IntegerAttr>(rhsOp.getValue()).getInt();
221+
}
158222
}
159223

160-
// shortcuts for special cases
224+
// shortcut for special cases
161225
if (lhsIntAttr) {
162226
if (lhsIntAttr.value() == 0)
163227
return lhs;
164228
if (lhsIntAttr.value() == 1)
165229
return rhs;
166230
}
167-
if (rhsIsConst) {
168-
if (rhsConstValue == 0)
169-
return rhsOp.getResult();
170-
if (rhsConstValue == 1)
231+
232+
if (rhsIntAttr) {
233+
if (rhsIntAttr.value() == 0)
234+
return rhs;
235+
if (rhsIntAttr.value() == 1)
171236
return lhs;
172237
}
173238

174-
// 0. both lhs and rhs are constants
175-
if (lhsIntAttr && rhsIsConst)
176-
return b.getIndexAttr(lhsIntAttr.value() * rhsConstValue);
239+
// both lhs and rhs are constants, return result directly
240+
if (lhsIntAttr && rhsIntAttr)
241+
return b.getIndexAttr(lhsIntAttr.value() * rhsIntAttr.value());
177242

178-
// 1. if lhs is constant but rhs is not
179-
if (lhsIntAttr && !rhsIsConst) {
180-
auto lhsConstOp =
243+
// otherwise, need to create instructions to calculate new attribute value
244+
if (lhsIntAttr) {
245+
auto lhsOp =
181246
b.create<arith::ConstantOp>(loc, b.getIndexAttr(lhsIntAttr.value()));
182-
auto mulOp = b.create<arith::MulIOp>(loc, lhsConstOp.getResult(), rhs);
183-
return mulOp.getResult();
247+
lhsValue = lhsOp.getResult();
248+
}
249+
250+
if (rhsIntAttr) {
251+
auto rhsOp =
252+
b.create<arith::ConstantOp>(loc, b.getIndexAttr(rhsIntAttr.value()));
253+
rhsValue = rhsOp.getResult();
184254
}
185255

186-
// 2. if lhs is not constant
187-
assert(!lhsIntAttr);
188-
auto mulOp = b.create<arith::MulIOp>(loc, cast<Value>(lhs), rhs);
189-
return mulOp.getResult();
256+
return b.create<arith::MulIOp>(loc, lhsValue, rhsValue).getResult();
190257
}
191258

192259
OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,

lib/Analysis/PtrAnalysis.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ void PtrState::mulState(const PtrState &lhsState, const PtrState &rhsState,
130130

131131
for (uint64_t i = 0; i < lhs->sizes.size(); i++) {
132132
OpFoldResult newOffset =
133-
mulOFRValue(lhs->offsets[i], rhs->scalar, loc, rewriter);
133+
mulOFRs(lhs->offsets[i], rhs->scalar, loc, rewriter);
134134
OpFoldResult newStride =
135-
mulOFRValue(lhs->strides[i], rhs->scalar, loc, rewriter);
135+
mulOFRs(lhs->strides[i], rhs->scalar, loc, rewriter);
136136
offsets.push_back(newOffset);
137137
strides.push_back(newStride);
138138
sizes.push_back(lhs->sizes[i]);

0 commit comments

Comments
 (0)