1212#include " mlir/IR/BuiltinTypes.h"
1313#include " mlir/IR/OpDefinition.h"
1414#include " mlir/Transforms/DialectConversion.h"
15+ #include " triton/Dialect/Triton/IR/Dialect.h"
1516
1617namespace mlir {
1718
@@ -74,6 +75,69 @@ SmallVector<Value> ofrsToIndexValues(ArrayRef<OpFoldResult> ofrs,
7475 }));
7576}
7677
78+ Value indexTypeCast (Value v, Type targetTy, const Location loc, OpBuilder &b) {
79+ Type ty = v.getType ();
80+ if (isa<IndexType>(targetTy) || isa<IndexType>(ty)) {
81+ assert ((isa<IntegerType>(targetTy) || isa<IntegerType>(ty)) &&
82+ " Only cast between index type and integer type" );
83+ return b.create <arith::IndexCastOp>(loc, targetTy, v).getResult ();
84+ } else {
85+ auto targetIntTy = cast<IntegerType>(targetTy);
86+ auto intTy = cast<IntegerType>(ty);
87+ if (targetIntTy.getWidth () > intTy.getWidth ())
88+ return b.create <arith::ExtSIOp>(loc, targetTy, v).getResult ();
89+ else
90+ return b.create <arith::TruncIOp>(loc, targetTy, v).getResult ();
91+ }
92+ }
93+
94+ OpFoldResult expandOFRIndex (OpFoldResult ofr, OpFoldResult targetForTy,
95+ const Location loc, OpBuilder &b) {
96+ if (getIntAttr (targetForTy))
97+ return ofr;
98+ Value targetValueForTy = cast<Value>(targetForTy);
99+ Type targetTy = targetValueForTy.getType ();
100+ auto targetShapedTy = dyn_cast<ShapedType>(targetTy);
101+
102+ Value v = dyn_cast<Value>(ofr);
103+ if (!v)
104+ v = b.create <arith::ConstantOp>(loc, cast<IntegerAttr>(cast<Attribute>(ofr)));
105+
106+ Type ty = v.getType ();
107+ if (targetTy == ty)
108+ return ofr;
109+
110+ auto shapedTy = dyn_cast<ShapedType>(ty);
111+ if (targetShapedTy && !shapedTy) {
112+ Type targetEltTy = targetShapedTy.getElementType ();
113+ // cast to target element type first.
114+ if (targetEltTy != ty)
115+ v = indexTypeCast (v, targetEltTy, loc, b);
116+ return b.create <triton::SplatOp>(loc, targetTy, v).getResult ();
117+ } else if (targetShapedTy && shapedTy) {
118+ // TODO: support ShapedType to ShapedType.
119+ Type targetEltTy = targetShapedTy.getElementType ();
120+ Type eltTy = shapedTy.getElementType ();
121+ if (targetShapedTy.getShape () != shapedTy.getShape ())
122+ llvm_unreachable (" ShapedType to ShapedType must have same shape" );
123+ if (isa<IndexType>(targetEltTy) || isa<IndexType>(eltTy)) {
124+ assert ((isa<IntegerType>(targetEltTy) || isa<IntegerType>(eltTy)) &&
125+ " Only cast between index type and integer type" );
126+ return b.create <arith::IndexCastOp>(loc, targetTy, v).getResult ();
127+ } else {
128+ auto targetIntTy = cast<IntegerType>(targetEltTy);
129+ auto intTy = cast<IntegerType>(eltTy);
130+ if (targetIntTy.getWidth () > intTy.getWidth ())
131+ return b.create <arith::ExtSIOp>(loc, targetTy, v).getResult ();
132+ else
133+ return b.create <arith::TruncIOp>(loc, targetTy, v).getResult ();
134+ }
135+ } else {
136+ assert (!shapedTy && " src type rank should be >= target type rank" );
137+ return indexTypeCast (v, targetTy, loc, b);
138+ }
139+ }
140+
77141OpFoldResult addOFRs (const OpFoldResult lhs, const OpFoldResult rhs,
78142 const Location loc, OpBuilder &b) {
79143 auto lhsIntAttr = getIntAttr (lhs);
@@ -95,17 +159,13 @@ OpFoldResult addOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
95159 auto lhsOp =
96160 b.create <arith::ConstantOp>(loc, b.getIndexAttr (lhsIntAttr.value ()));
97161 lhsValue = lhsOp.getResult ();
98- } else {
99- assert (isa<IndexType>(lhsValue.getType ()));
100162 }
101163
102164 auto rhsValue = dyn_cast<Value>(rhs);
103165 if (rhsIntAttr) {
104166 auto rhsOp =
105167 b.create <arith::ConstantOp>(loc, b.getIndexAttr (rhsIntAttr.value ()));
106168 rhsValue = rhsOp.getResult ();
107- } else {
108- assert (isa<IndexType>(lhsValue.getType ()));
109169 }
110170
111171 return b.create <arith::AddIOp>(loc, lhsValue, rhsValue).getResult ();
@@ -143,50 +203,57 @@ OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
143203 return sumOp.getResult ();
144204}
145205
146- OpFoldResult mulOFRValue (const OpFoldResult lhs, const Value rhs,
206+ OpFoldResult mulOFRs (const OpFoldResult lhs, const OpFoldResult rhs,
147207 const Location loc, OpBuilder &b) {
148208 auto lhsIntAttr = getIntAttr (lhs);
209+ auto rhsIntAttr = getIntAttr (rhs);
149210
150- auto rhsIsConst = false ;
151- // if rhs is not a const, use max value since min is used to represent
152- // dynamic size or stride
153- auto rhsConstValue = std::numeric_limits<int64_t >::max ();
154- auto rhsOp = rhs.getDefiningOp <arith::ConstantOp>();
155- if (rhsOp) {
156- rhsIsConst = true ;
157- rhsConstValue = cast<IntegerAttr>(rhsOp.getValue ()).getInt ();
211+ auto lhsValue = dyn_cast<Value>(lhs);
212+ if (lhsValue) {
213+ if (auto lhsOp = lhsValue.getDefiningOp <arith::ConstantOp>()) {
214+ lhsIntAttr = cast<IntegerAttr>(lhsOp.getValue ()).getInt ();
215+ }
216+ }
217+ auto rhsValue = dyn_cast<Value>(rhs);
218+ if (rhsValue) {
219+ if (auto rhsOp = rhsValue.getDefiningOp <arith::ConstantOp>()) {
220+ rhsIntAttr = cast<IntegerAttr>(rhsOp.getValue ()).getInt ();
221+ }
158222 }
159223
160- // shortcuts for special cases
224+ // shortcut for special cases
161225 if (lhsIntAttr) {
162226 if (lhsIntAttr.value () == 0 )
163227 return lhs;
164228 if (lhsIntAttr.value () == 1 )
165229 return rhs;
166230 }
167- if (rhsIsConst) {
168- if (rhsConstValue == 0 )
169- return rhsOp.getResult ();
170- if (rhsConstValue == 1 )
231+
232+ if (rhsIntAttr) {
233+ if (rhsIntAttr.value () == 0 )
234+ return rhs;
235+ if (rhsIntAttr.value () == 1 )
171236 return lhs;
172237 }
173238
174- // 0. both lhs and rhs are constants
175- if (lhsIntAttr && rhsIsConst )
176- return b.getIndexAttr (lhsIntAttr.value () * rhsConstValue );
239+ // both lhs and rhs are constants, return result directly
240+ if (lhsIntAttr && rhsIntAttr )
241+ return b.getIndexAttr (lhsIntAttr.value () * rhsIntAttr. value () );
177242
178- // 1. if lhs is constant but rhs is not
179- if (lhsIntAttr && !rhsIsConst ) {
180- auto lhsConstOp =
243+ // otherwise, need to create instructions to calculate new attribute value
244+ if (lhsIntAttr) {
245+ auto lhsOp =
181246 b.create <arith::ConstantOp>(loc, b.getIndexAttr (lhsIntAttr.value ()));
182- auto mulOp = b.create <arith::MulIOp>(loc, lhsConstOp.getResult (), rhs);
183- return mulOp.getResult ();
247+ lhsValue = lhsOp.getResult ();
248+ }
249+
250+ if (rhsIntAttr) {
251+ auto rhsOp =
252+ b.create <arith::ConstantOp>(loc, b.getIndexAttr (rhsIntAttr.value ()));
253+ rhsValue = rhsOp.getResult ();
184254 }
185255
186- // 2. if lhs is not constant
187- assert (!lhsIntAttr);
188- auto mulOp = b.create <arith::MulIOp>(loc, cast<Value>(lhs), rhs);
189- return mulOp.getResult ();
256+ return b.create <arith::MulIOp>(loc, lhsValue, rhsValue).getResult ();
190257}
191258
192259OpFoldResult minOFRs (const OpFoldResult lhs, const OpFoldResult rhs,
0 commit comments