Skip to content

Commit 9c52ead

Browse files
[AIROCMLIR-708] Fix validity masking for padded transform indices (#2353)
* fix validity masking for padded transform indices
  Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>
* refactor validity check and update test IR expectations
  Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>
* move zeroConst inside addLowerDimBoundsCheck lambda
  Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>
1 parent 1714bad commit 9c52ead

2 files changed

Lines changed: 40 additions & 10 deletions

File tree

mlir/lib/Dialect/Rock/utility/transformMapUtils.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,15 +1077,19 @@ Value mlir::rock::updateValidityAfter(OpBuilder &b, Location loc,
10771077
Value isValid =
10781078
b.createOrFold<arith::ConstantIntOp>(loc, b.getI1Type(), true);
10791079
ArrayRef<int64_t> lowerBounds = map.getLowerBounds();
1080-
1081-
// unsigned < catches both negatives (as all negatives are > the bound)
1082-
// and being too large on the right.
1083-
auto addLowerDimUltClamp = [&](uint32_t lowerDim) {
1080+
// Explicitly check both bounds. Left padding can produce negative indices,
1081+
// while right padding can produce indices >= bound.
1082+
auto addLowerDimBoundsCheck = [&](uint32_t lowerDim) {
10841083
int64_t bound = lowerBounds[lowerDim];
1084+
Value zeroConst = b.createOrFold<arith::ConstantIndexOp>(loc, 0);
10851085
Value boundConst = b.createOrFold<arith::ConstantIndexOp>(loc, bound);
10861086
Value output = outputs[lowerDim];
1087-
Value inBounds = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ult,
1088-
output, boundConst);
1087+
Value geLowerBound = arith::CmpIOp::create(
1088+
b, loc, arith::CmpIPredicate::sge, output, zeroConst);
1089+
Value ltUpperBound = arith::CmpIOp::create(
1090+
b, loc, arith::CmpIPredicate::slt, output, boundConst);
1091+
Value inBounds = b.createOrFold<arith::AndIOp>(loc, b.getI1Type(),
1092+
geLowerBound, ltUpperBound);
10891093
isValid =
10901094
b.createOrFold<arith::AndIOp>(loc, b.getI1Type(), inBounds, isValid);
10911095
};
@@ -1102,13 +1106,13 @@ Value mlir::rock::updateValidityAfter(OpBuilder &b, Location loc,
11021106

11031107
if (params[leftParam] == 0 && params[rightParam] == 0)
11041108
continue;
1105-
addLowerDimUltClamp(lowerDim);
1109+
addLowerDimBoundsCheck(lowerDim);
11061110
}
11071111
}
11081112
if (type == TransformType::Embed) {
11091113
if (!embedCanBeInvalid(map, op))
11101114
continue;
1111-
addLowerDimUltClamp(op.getLowerDims()[0]);
1115+
addLowerDimBoundsCheck(op.getLowerDims()[0]);
11121116
}
11131117
}
11141118
return isValid;

mlir/test/Dialect/Rock/lowering_transforming_for.mlir

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
by [<Pad{0, 8} ["pad"] at [0] -> ["raw"] at [0]>]
1414
bounds = [16] -> [8]>
1515

16+
#transform_map_pad_left = #rock.transform_map<affine_map<(d0) -> (d0 - 1)>
17+
by [<Pad{1, 7} ["pad"] at [0] -> ["raw"] at [0]>]
18+
bounds = [16] -> [8]>
19+
1620
module {
1721
// CHECK-LABEL: func.func @no_transform_to_affine
1822
func.func @no_transform_to_affine() {
@@ -294,10 +298,13 @@ func.func @no_loop_loop_result(%arg0: index, %arg1: index) -> index {
294298

295299

296300
// CHECK-LABEL: func.func @bounds_check_pad
301+
// CHECK-DAG: %[[c0:.*]] = arith.constant 0
297302
// CHECK-DAG: %[[c8:.*]] = arith.constant 8
298303
// CHECK: affine.for %[[num:.*]] = {{.*}}to 16
299-
// CHECK: %[[valid:.*]] = arith.cmpi ult, %[[num]], %[[c8]]
300-
// CHECK: gpu.printf "%d", %1
304+
// CHECK-DAG: %[[ge:.*]] = arith.cmpi sge, %[[num]], %[[c0]]
305+
// CHECK-DAG: %[[lt:.*]] = arith.cmpi slt, %[[num]], %[[c8]]
306+
// CHECK: %[[valid:.*]] = arith.andi %[[ge]], %[[lt]]
307+
// CHECK: gpu.printf "%d", %{{.*}}
301308
func.func @bounds_check_pad() {
302309
%c0 = arith.constant 0 : index
303310
rock.transforming_for (%arg0) = [#transform_map_pad](%c0) (%arg1) = validity bounds [16] strides [1] {
@@ -306,4 +313,23 @@ func.func @bounds_check_pad() {
306313
}
307314
return
308315
}
316+
317+
// CHECK-LABEL: func.func @bounds_check_pad_left
318+
// CHECK-DAG: %[[c0:.*]] = arith.constant 0
319+
// CHECK-DAG: %[[c8:.*]] = arith.constant 8
320+
// CHECK-DAG: %[[cm1:.*]] = arith.constant -1
321+
// CHECK: affine.for %[[num:.*]] = {{.*}}to 16
322+
// CHECK: %[[shifted:.*]] = arith.addi %[[num]], %[[cm1]]
323+
// CHECK-DAG: %[[ge:.*]] = arith.cmpi sge, %[[shifted]], %[[c0]]
324+
// CHECK-DAG: %[[lt:.*]] = arith.cmpi slt, %[[shifted]], %[[c8]]
325+
// CHECK: %[[valid:.*]] = arith.andi %[[ge]], %[[lt]]
326+
// CHECK: gpu.printf "%d", %{{.*}}
327+
func.func @bounds_check_pad_left() {
328+
%c0 = arith.constant 0 : index
329+
rock.transforming_for (%arg0) = [#transform_map_pad_left](%c0) (%arg1) = validity bounds [16] strides [1] {
330+
%arg1_i32 = arith.extui %arg1 : i1 to i32
331+
gpu.printf "%d", %arg1_i32 : i32
332+
}
333+
return
334+
}
309335
}

0 commit comments

Comments
 (0)