
Commit 8036b43

python3kgae and Xiang Li authored
Support stride for unstructured PtrState. (#289)
Allow the stride of an unstructured PtrState to be a value other than 1. This happens when the PtrState is built from mul operations, as in:

```
x_ptr + offsets_m[:, None] * stride_m + offsets_n[None, :] * stride_n
```

mulState is changed so that, for an unstructured dim, only the scalar is multiplied into the stride.

The addState logic is simplified (see the sketch below):

- If one PtrState does not contribute to the current dim, just use the other PtrState.
- If the strides of the two PtrStates are equal, as in `lhs_offset * stride + rhs_offset * stride`, treat it as `(lhs_offset + rhs_offset) * stride` by adding only the offsets.
- If the offsets of the two PtrStates are equal, as in `offset * lhs_stride + offset * rhs_stride`, treat it as `offset * (lhs_stride + rhs_stride)` by adding only the strides.
- If neither the offsets nor the strides are equal, rewrite `offset * stride` as `(offset * stride) * 1` by setting the new offset to `offset * stride` and the new stride to 1. The two PtrStates then have equal strides and take the equal-stride path.

With this change, tts.make_gather_scatter_tptr gets the correct strides.

---------

Co-authored-by: Xiang Li <xiagli@microsoft.com>
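A minimal sketch of the merge rules above (plain Python with hypothetical names, not the actual PtrState API, which works on MLIR OpFoldResults): each dimension is modeled as an (offset, stride) pair.

```python
def add_dim(lhs, rhs):
    """Merge the (offset, stride) pairs of one dimension for lhs + rhs."""
    (lhs_off, lhs_stride), (rhs_off, rhs_stride) = lhs, rhs
    if lhs_off == 0 and lhs_stride == 0:
        return rhs  # lhs does not contribute to this dim
    if rhs_off == 0 and rhs_stride == 0:
        return lhs  # rhs does not contribute to this dim
    if lhs_off != rhs_off and lhs_stride != rhs_stride:
        # Rewrite offset * stride as (offset * stride) * 1 on both sides,
        # so the equal-stride path below applies.
        lhs_off, lhs_stride = lhs_off * lhs_stride, 1
        rhs_off, rhs_stride = rhs_off * rhs_stride, 1
    if lhs_stride == rhs_stride:
        # lhs_off * s + rhs_off * s == (lhs_off + rhs_off) * s
        return (lhs_off + rhs_off, lhs_stride)
    # off * ls + off * rs == off * (ls + rs)
    assert lhs_off == rhs_off, "unequal strides imply equal offsets"
    return (lhs_off, lhs_stride + rhs_stride)

assert add_dim((3, 4), (0, 0)) == (3, 4)   # other state not for this dim
assert add_dim((3, 4), (5, 4)) == (8, 4)   # equal strides: offsets add
assert add_dim((3, 4), (3, 2)) == (3, 6)   # equal offsets: strides add
assert add_dim((3, 4), (5, 2)) == (22, 1)  # neither: fold, stride becomes 1
```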
1 parent e1ca305 commit 8036b43

10 files changed: 213 additions & 140 deletions

lib/AnalysisStructured/PtrAnalysis.cpp

Lines changed: 104 additions & 39 deletions
```diff
@@ -145,13 +145,16 @@ LogicalResult PtrState::rebuildAsUnsupportedOp(Value operand) {
   // Setup state for unsupported operation.
   auto indexTy = IndexType::get(operand.getContext());
   auto index0 = IntegerAttr::get(indexTy, APInt(64, 0));
+  auto index1 = IntegerAttr::get(indexTy, APInt(64, 1));
   for (auto size : opShape) {
-    if (size == 1)
+    if (size == 1) {
       offsets.push_back(index0);
-    else
+      strides.push_back(index0);
+    } else {
       offsets.push_back(operand);
+      strides.push_back(index1);
+    }
     sizes.push_back(IntegerAttr::get(indexTy, APInt(64, size)));
-    strides.push_back(index0);
     shape.push_back(index0);
   }
   return success();
```
```diff
@@ -174,9 +177,10 @@ LogicalResult PtrState::rebuildAsGatherScatter(Value op, int nonContinuousDim) {
   // Setup state for nonContinuousDim.
   auto indexTy = IndexType::get(op.getContext());
   auto index0 = IntegerAttr::get(indexTy, APInt(64, 0));
+  auto index1 = IntegerAttr::get(indexTy, APInt(64, 1));

   offsets[nonContinuousDim] = op;
-  strides[nonContinuousDim] = index0;
+  strides[nonContinuousDim] = index1;
   shape[nonContinuousDim] = index0;
   return success();
 }
```
```diff
@@ -222,43 +226,105 @@ LogicalResult PtrState::addState(const PtrState &lhsState,
           addOFRs(lhsState.strides[i], rhsState.strides[i], loc, builder);
       strides.push_back(newStride);
     } else {
-      // Set stride to 1 when not continuous.
-      strides.push_back(builder.getIndexAttr(1));
-      // New offset is offset * stride.
-      auto newLhsOffset = lhsState.offsets[i];
-      auto newRhsOffset = rhsState.offsets[i];
       if (isAnalysisingUnstructured) {
         assert(!lhsState.hasModulo() && !rhsState.hasModulo() &&
                "should not have dimension with modulo when analysing "
                "unstructured");
-        // When the dimension is structured, mul the offset by the stride to
-        // match the stride 1 for non-structured dimensions.
-        // If the dimension is not structured, the offset is already multiplied
-        // by the stride.
-        // If stride is 0 which will happen after
-        // visitOperandExpandDims/visitOperandSplat, we cannot mul which will
-        // get zero and lost the offset.
-        if (lhsState.dimIsStructured(i) && !hasConstZero(lhsState.strides[i])) {
-          auto stride = expandOFRIndex(lhsState.strides[i], lhsState.offsets[i],
-                                       loc, builder);
-          newLhsOffset = mulOFRs(lhsState.offsets[i], stride, loc, builder);
-        }
-        if (rhsState.dimIsStructured(i) && !hasConstZero(rhsState.strides[i])) {
-          auto stride = expandOFRIndex(rhsState.strides[i], rhsState.offsets[i],
-                                       loc, builder);
-          newRhsOffset = mulOFRs(rhsState.offsets[i], stride, loc, builder);
-        }
-        // Make sure newLhsOffset and newRhsOffset get same type.
-        if (!lhsState.dimIsStructured(i)) {
-          newRhsOffset =
-              expandOFRIndex(newRhsOffset, newLhsOffset, loc, builder);
+        if (hasConstZero(lhsState.strides[i]) &&
+            hasConstZero(lhsState.offsets[i])) {
+          // If lhs is not for dim i, just use rhs's stride and offset.
+          offsets.push_back(rhsState.offsets[i]);
+          strides.push_back(rhsState.strides[i]);
+        } else if (hasConstZero(rhsState.strides[i]) &&
+                   hasConstZero(rhsState.offsets[i])) {
+          // If rhs is not for dim i, just use lhs's stride and offset.
+          offsets.push_back(lhsState.offsets[i]);
+          strides.push_back(lhsState.strides[i]);
         } else {
-          newLhsOffset =
-              expandOFRIndex(newLhsOffset, newRhsOffset, loc, builder);
+          OpFoldResult lhsOffset = lhsState.offsets[i];
+          OpFoldResult rhsOffset = rhsState.offsets[i];
+          OpFoldResult lhsStride = lhsState.strides[i];
+          OpFoldResult rhsStride = rhsState.strides[i];
+          // If the stride is 0, which happens after
+          // visitOperandExpandDims/visitOperandSplat, set it to 1 so it can
+          // be multiplied with the offset.
+          if (hasConstZero(lhsStride)) {
+            assert(lhsState.dimIsStructured(i) &&
+                   !rhsState.dimIsStructured(i) &&
+                   "If lhs stride is zero, it must be structured and rhs "
+                   "stride is unstructured");
+            lhsStride = builder.getIndexAttr(1);
+          }
+          if (hasConstZero(rhsStride)) {
+            assert(rhsState.dimIsStructured(i) &&
+                   !lhsState.dimIsStructured(i) &&
+                   "If rhs stride is zero, it must be structured and lhs "
+                   "stride is unstructured");
+            rhsStride = builder.getIndexAttr(1);
+          }
+
+          // If neither the offsets nor the strides are equal, merge the two
+          // PtrStates by rewriting offset * stride as (offset * stride) * 1:
+          // the new offset is offset * stride and the new stride is 1. Both
+          // strides are then 1, so the states merge on the equal-stride path
+          // below.
+          if (lhsOffset != rhsOffset && lhsStride != rhsStride) {
+            // Expand the stride since an unstructured offset has tensor type.
+            OpFoldResult stride =
+                expandOFRIndex(lhsStride, lhsOffset, loc, builder);
+            // new offset = offset * stride
+            lhsOffset = mulOFRs(lhsOffset, stride, loc, builder);
+            // Expand the stride since an unstructured offset has tensor type.
+            stride = expandOFRIndex(rhsStride, rhsOffset, loc, builder);
+            // new offset = offset * stride
+            rhsOffset = mulOFRs(rhsOffset, stride, loc, builder);
+            // Set both strides to 1.
+            lhsStride = builder.getIndexAttr(1);
+            rhsStride = builder.getIndexAttr(1);
+          }
+
+          if (lhsStride == rhsStride) {
+            // For a case like lhs_offset * stride + rhs_offset * stride,
+            // which equals (lhs_offset + rhs_offset) * stride, add the
+            // offsets and reuse the stride:
+            //   offsets[i] = lhsOffset + rhsOffset
+            //   strides[i] = lhsStride
+            // Expand the structured offset since an unstructured offset has
+            // tensor type.
+            if (!lhsState.dimIsStructured(i)) {
+              rhsOffset = expandOFRIndex(rhsOffset, lhsOffset, loc, builder);
+            } else {
+              lhsOffset = expandOFRIndex(lhsOffset, rhsOffset, loc, builder);
+            }
+            // Add offsets.
+            offsets.push_back(addOFRs(lhsOffset, rhsOffset, loc, builder));
+            // Reuse stride.
+            strides.push_back(lhsStride);
+          } else {
+            // If the strides are not equal, the offsets must be equal,
+            // because the case where both differ was already forced to
+            // stride 1 above.
+            assert(lhsOffset == rhsOffset &&
+                   "If strides are not equal, offsets must be equal");
+            // For a case like offset * lhs_stride + offset * rhs_stride,
+            // which equals offset * (lhs_stride + rhs_stride), add the
+            // strides and reuse the offset:
+            //   offsets[i] = lhsOffset
+            //   strides[i] = lhsStride + rhsStride
+
+            // Reuse offsets.
+            offsets.push_back(lhsOffset);
+            // Add strides.
+            strides.push_back(addOFRs(lhsStride, rhsStride, loc, builder));
+          }
         }
-        auto newOffset = addOFRs(newLhsOffset, newRhsOffset, loc, builder);
-        offsets.push_back(newOffset);
       } else {
+        // Set stride to 1 when not continuous.
+        strides.push_back(builder.getIndexAttr(1));
+        // New offset is offset * stride.
+        auto newLhsOffset = lhsState.offsets[i];
+        auto newRhsOffset = rhsState.offsets[i];
         // Just propagate the unstructured offset to the result to track the
         // unstructured dimension. The real address calculation will be done
        // later in the PtrAnalysis::visitOperandAddptr.
```
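The stride-0 normalization in this hunk can be shown with a tiny sketch (plain Python, hypothetical names; the real code works on OpFoldResults). After visitOperandExpandDims/visitOperandSplat a structured dimension carries stride 0, and multiplying its offset by that stride would lose the offset, so the stride is bumped to 1 first:

```python
def normalize_zero_stride(offset, stride, dim_is_structured):
    # Mirrors the asserts in the hunk: only a structured dim paired with an
    # unstructured one is expected to carry stride 0 here.
    if stride == 0:
        assert dim_is_structured
        stride = 1
    return offset, stride

# A splat scalar contributes (offset=7, stride=0); after normalization it
# merges with an unstructured (offsets, stride=1) state on the equal-stride
# path, so the offsets simply add.
assert normalize_zero_stride(7, 0, dim_is_structured=True) == (7, 1)
```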
```diff
@@ -432,13 +498,12 @@ LogicalResult PtrState::mulState(const PtrState &lhsState,
     assert(!lhs->dimHasModulo(i) &&
            "should not have non-structured dimension with modulo");
     if (isAnalysisingUnstructured) {
-      auto rhsStride =
-          expandOFRIndex(rhs->scalar, lhs->offsets[i], loc, builder);
       assert(!lhs->hasModulo() &&
              "should not have non-structured dimension with modulo");
-      OpFoldResult newOffset =
-          mulOFRs(lhs->offsets[i], rhsStride, loc, builder);
-      offsets.push_back(newOffset);
+      // Keep offsets as-is for the unstructured dimension.
+      // The address calculation will be done later in the
+      // structured-to-memref pass.
+      offsets.push_back(lhs->offsets[i]);
       // Mul the scalar to stride.
       OpFoldResult newStride =
           mulOFRs(lhs->strides[i], rhs->scalar, loc, builder);
```
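A hedged sketch of the mulState change (plain Python, hypothetical names): for an unstructured dimension the scalar multiplier now folds into the stride, leaving the gather offsets untouched so the stride survives to tts.make_gather_scatter_tptr.

```python
def mul_unstructured_dim(offsets, stride, scalar):
    # Before this commit the offsets were multiplied by the scalar and the
    # stride stayed at 1, losing the stride information.
    return offsets, stride * scalar

# offsets_m[:, None] * stride_m with stride_m == 8: offsets are kept,
# the stride becomes 8.
assert mul_unstructured_dim([0, 2, 5], 1, 8) == ([0, 2, 5], 8)
```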

lib/Conversion/StructuredToMemref/StructuredToMemref.cpp

Lines changed: 23 additions & 1 deletion
```diff
@@ -137,6 +137,27 @@ static OpFoldResult accumulateTargetOffset(Location loc,
   return targetOffset;
 }

+static OpFoldResult accumulateTargetOffset(Location loc,
+                                           ArrayRef<OpFoldResult> offsets,
+                                           ArrayRef<OpFoldResult> strides,
+                                           int gatherDim,
+                                           OpBuilder &b) {
+  OpFoldResult targetOffset = b.getIndexAttr(0);
+  for (int i = 0; i < offsets.size(); i++) {
+    OpFoldResult offset = offsets[i];
+    // If this is the gather dimension, multiply the offset by the stride.
+    // Non-gather dimensions are already multiplied by the stride in the
+    // offsets in PtrAnalysis.
+    if (i == gatherDim) {
+      OpFoldResult stride = strides[i];
+      offset = mulOFRs(offset, stride, loc, b);
+    }
+    targetOffset = addOFRs(targetOffset, offset, loc, b);
+  }
+  return targetOffset;
+}
+
 static Value rewriteGatherScatterPtrElement(
     ArrayRef<int64_t> resultShape, tts::MakeGatherScatterTensorPtrOp op,
     Value basePtr, Value gatherOffsetElt, int gatherDim,
```
```diff
@@ -149,7 +170,8 @@ static Value rewriteGatherScatterPtrElement(

   auto offsets = op.getMixedOffsets();
   offsets[gatherDim] = gatherOffsetElt;
-  auto targetOffset = accumulateTargetOffset(op.getLoc(), offsets, rewriter);
+  auto targetOffset = accumulateTargetOffset(op.getLoc(), offsets,
+                                             mixedStrides, gatherDim, rewriter);

   auto staticTargetOffset = getIntAttr(targetOffset);
   auto resultType =
```
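The arithmetic of the new overload can be modeled in a few lines (plain Python, illustrative names): only the gather dimension still needs its offset scaled by the stride, because PtrAnalysis has already folded offset * stride into the other dimensions' offsets.

```python
def accumulate_target_offset(offsets, strides, gather_dim):
    target = 0
    for i, offset in enumerate(offsets):
        if i == gather_dim:
            # Only the gather dim is scaled here; the other offsets were
            # pre-multiplied by their strides in PtrAnalysis.
            offset *= strides[i]
        target += offset
    return target

# Gather element offset 5 with stride 16 on dim 0, plus a pre-folded offset
# of 3 on dim 1: 5 * 16 + 3 == 83.
assert accumulate_target_offset([5, 3], [16, 1], gather_dim=0) == 83
```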

python/examples/test_index_select.py

Lines changed: 1 addition & 16 deletions
```diff
@@ -219,7 +219,7 @@ def index_select_row_with_double_mod2(input_tensor, indices, dim, mod_offset, mo
     o_stride_m = output_tensor.stride(0)
     o_stride_n = output_tensor.stride(1)

-    a = index_select_row_with_double_mod_kernel2[1,](
+    index_select_row_with_double_mod_kernel2[1,](
         input_tensor,
         output_tensor,
         indices,
```
```diff
@@ -235,21 +235,6 @@ def index_select_row_with_double_mod2(input_tensor, indices, dim, mod_offset, mo
     )
     return output_tensor

-    index_select_row_with_mod_kernel[1,](
-        input_tensor,
-        output_tensor,
-        indices,
-        stride_i,
-        stride_m,
-        stride_n,
-        o_stride_m,
-        o_stride_n,
-        mod_offset,
-        BLOCK_I=R,
-        BLOCK_N=N,
-    )
-    return output_tensor
-

 def test_index_select_row_with_double_mod2(device):
     M, N = 16, 16
```

test/Conversion/StructuredToMemref/gather_scatter_ptr_to_linalg.mlir

Lines changed: 20 additions & 20 deletions
```diff
@@ -34,8 +34,8 @@
 // CHECK-SAME: %[[VAL_18:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) {
 // CHECK: %[[VAL_19:.*]] = arith.constant 1 : index
 // CHECK: %[[VAL_20:.*]] = arith.constant 8 : i32
-// CHECK: %[[VAL_21:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_22:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_21:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_22:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_23:.*]] = tensor.empty() : tensor<16x1xi32>
 // CHECK: %[[VAL_24:.*]] = linalg.fill ins(%[[VAL_20]] : i32) outs(%[[VAL_23]] : tensor<16x1xi32>) -> tensor<16x1xi32>
 // CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_5]], %[[VAL_20]] : i32
```
```diff
@@ -69,27 +69,27 @@
 // CHECK: linalg.yield %[[VAL_52]] : i32
 // CHECK: } -> tensor<16x1xi32>
 // CHECK: %[[VAL_53:.*]] = arith.index_cast %[[VAL_6]] : i64 to index
-// CHECK: %[[VAL_54:.*]] = arith.index_cast %[[VAL_53]] : index to i32
-// CHECK: %[[VAL_55:.*]] = linalg.fill ins(%[[VAL_54]] : i32) outs(%[[VAL_23]] : tensor<16x1xi32>) -> tensor<16x1xi32>
-// CHECK: %[[VAL_56:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_48]], %[[VAL_55]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_48]] : tensor<16x1xi32>) {
-// CHECK: ^bb0(%[[VAL_57:.*]]: i32, %[[VAL_58:.*]]: i32, %[[VAL_59:.*]]: i32):
-// CHECK: %[[VAL_60:.*]] = arith.muli %[[VAL_57]], %[[VAL_58]] : i32
-// CHECK: linalg.yield %[[VAL_60]] : i32
+// CHECK: %[[VAL_54:.*]] = linalg.fill ins(%[[VAL_37]] : i32) outs(%[[VAL_23]] : tensor<16x1xi32>) -> tensor<16x1xi32>
+// CHECK: %[[VAL_55:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_32]], %[[VAL_54]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_32]] : tensor<16x1xi32>) {
+// CHECK: ^bb0(%[[VAL_56:.*]]: i32, %[[VAL_57:.*]]: i32, %[[VAL_58:.*]]: i32):
+// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_56]], %[[VAL_57]] : i32
+// CHECK: linalg.yield %[[VAL_59]] : i32
 // CHECK: } -> tensor<16x1xi32>
-// CHECK: %[[VAL_61:.*]] = linalg.fill ins(%[[VAL_37]] : i32) outs(%[[VAL_23]] : tensor<16x1xi32>) -> tensor<16x1xi32>
-// CHECK: %[[VAL_62:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_32]], %[[VAL_61]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_32]] : tensor<16x1xi32>) {
+// CHECK: %[[VAL_60:.*]] = arith.index_cast %[[VAL_53]] : index to i32
+// CHECK: %[[VAL_61:.*]] = linalg.fill ins(%[[VAL_60]] : i32) outs(%[[VAL_23]] : tensor<16x1xi32>) -> tensor<16x1xi32>
+// CHECK: %[[VAL_62:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_48]], %[[VAL_61]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_48]] : tensor<16x1xi32>) {
 // CHECK: ^bb0(%[[VAL_63:.*]]: i32, %[[VAL_64:.*]]: i32, %[[VAL_65:.*]]: i32):
-// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_63]], %[[VAL_64]] : i32
+// CHECK: %[[VAL_66:.*]] = arith.muli %[[VAL_63]], %[[VAL_64]] : i32
 // CHECK: linalg.yield %[[VAL_66]] : i32
 // CHECK: } -> tensor<16x1xi32>
 // CHECK: %[[VAL_67:.*]] = arith.index_cast %[[VAL_46]] : index to i32
 // CHECK: %[[VAL_68:.*]] = linalg.fill ins(%[[VAL_67]] : i32) outs(%[[VAL_23]] : tensor<16x1xi32>) -> tensor<16x1xi32>
-// CHECK: %[[VAL_69:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_62]], %[[VAL_68]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_62]] : tensor<16x1xi32>) {
+// CHECK: %[[VAL_69:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_55]], %[[VAL_68]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_55]] : tensor<16x1xi32>) {
 // CHECK: ^bb0(%[[VAL_70:.*]]: i32, %[[VAL_71:.*]]: i32, %[[VAL_72:.*]]: i32):
 // CHECK: %[[VAL_73:.*]] = arith.muli %[[VAL_70]], %[[VAL_71]] : i32
 // CHECK: linalg.yield %[[VAL_73]] : i32
 // CHECK: } -> tensor<16x1xi32>
-// CHECK: %[[VAL_74:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_56]], %[[VAL_69]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_56]] : tensor<16x1xi32>) {
+// CHECK: %[[VAL_74:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_62]], %[[VAL_69]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_62]] : tensor<16x1xi32>) {
 // CHECK: ^bb0(%[[VAL_75:.*]]: i32, %[[VAL_76:.*]]: i32, %[[VAL_77:.*]]: i32):
 // CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_75]], %[[VAL_76]] : i32
 // CHECK: linalg.yield %[[VAL_78]] : i32
```
```diff
@@ -103,12 +103,12 @@
 // CHECK: } -> tensor<16x1xi32>
 // CHECK: %[[VAL_86:.*]] = tensor.collapse_shape %[[VAL_81]] {{\[\[}}0, 1]] : tensor<16x1xi32> into tensor<16xi32>
 // CHECK: %[[VAL_87:.*]] = arith.index_cast %[[VAL_25]] : i32 to index
-// CHECK: %[[VAL_88:.*]] = arith.minsi %[[VAL_87]], %[[VAL_22]] : index
-// CHECK: %[[VAL_89:.*]] = arith.maxsi %[[VAL_88]], %[[VAL_21]] : index
-// CHECK: %[[VAL_90:.*]] = arith.minsi %[[VAL_89]], %[[VAL_22]] : index
+// CHECK: %[[VAL_88:.*]] = arith.minsi %[[VAL_87]], %[[VAL_21]] : index
+// CHECK: %[[VAL_89:.*]] = arith.maxsi %[[VAL_88]], %[[VAL_22]] : index
+// CHECK: %[[VAL_90:.*]] = arith.minsi %[[VAL_89]], %[[VAL_21]] : index
 // CHECK: %[[VAL_91:.*]] = memref.alloc() : memref<16x16xf32>
-// CHECK: %[[VAL_92:.*]] = arith.minsi %[[VAL_90]], %[[VAL_22]] : index
-// CHECK: scf.for %[[VAL_93:.*]] = %[[VAL_21]] to %[[VAL_92]] step %[[VAL_19]] {
+// CHECK: %[[VAL_92:.*]] = arith.minsi %[[VAL_90]], %[[VAL_21]] : index
+// CHECK: scf.for %[[VAL_93:.*]] = %[[VAL_22]] to %[[VAL_92]] step %[[VAL_19]] {
 // CHECK: %[[VAL_94:.*]] = tensor.extract %[[VAL_86]]{{\[}}%[[VAL_93]]] : tensor<16xi32>
 // CHECK: %[[VAL_95:.*]] = arith.index_cast %[[VAL_94]] : i32 to index
 // CHECK: %[[VAL_96:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_95]]], sizes: [1, 16], strides: [1, 1] : memref<*xf32> to memref<1x16xf32, strided<[1, 1], offset: ?>>
```
```diff
@@ -133,7 +133,7 @@
 // CHECK: } -> tensor<16x1xi32>
 // CHECK: %[[VAL_114:.*]] = arith.index_cast %[[VAL_105]] : index to i32
 // CHECK: %[[VAL_115:.*]] = linalg.fill ins(%[[VAL_114]] : i32) outs(%[[VAL_23]] : tensor<16x1xi32>) -> tensor<16x1xi32>
-// CHECK: %[[VAL_116:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_62]], %[[VAL_115]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_62]] : tensor<16x1xi32>) {
+// CHECK: %[[VAL_116:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_55]], %[[VAL_115]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_55]] : tensor<16x1xi32>) {
 // CHECK: ^bb0(%[[VAL_117:.*]]: i32, %[[VAL_118:.*]]: i32, %[[VAL_119:.*]]: i32):
 // CHECK: %[[VAL_120:.*]] = arith.muli %[[VAL_117]], %[[VAL_118]] : i32
 // CHECK: linalg.yield %[[VAL_120]] : i32
```
```diff
@@ -151,7 +151,7 @@
 // CHECK: linalg.yield %[[VAL_132]] : i32
 // CHECK: } -> tensor<16x1xi32>
 // CHECK: %[[VAL_133:.*]] = tensor.collapse_shape %[[VAL_128]] {{\[\[}}0, 1]] : tensor<16x1xi32> into tensor<16xi32>
-// CHECK: scf.for %[[VAL_134:.*]] = %[[VAL_21]] to %[[VAL_92]] step %[[VAL_19]] {
+// CHECK: scf.for %[[VAL_134:.*]] = %[[VAL_22]] to %[[VAL_92]] step %[[VAL_19]] {
 // CHECK: %[[VAL_135:.*]] = tensor.extract %[[VAL_133]]{{\[}}%[[VAL_134]]] : tensor<16xi32>
 // CHECK: %[[VAL_136:.*]] = arith.index_cast %[[VAL_135]] : i32 to index
 // CHECK: %[[VAL_137:.*]] = tensor.extract_slice %[[VAL_99]]{{\[}}%[[VAL_134]], 0] [1, 8] [1, 1] : tensor<16x16xf32> to tensor<1x8xf32>
```
