Skip to content

Commit cc96f5e

Browse files
committed
populate stride info for ouput of memref.reinterpret_cast when lowering make_ptr_blk
1 parent 7f38361 commit cc96f5e

26 files changed

Lines changed: 104 additions & 131 deletions

lib/Conversion/StructuredToMemref/StructuredToMemref.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ struct MakeTensorPtrConverter
108108
auto strideIntAttr = getIntAttr(stride);
109109
if (size == 1 && strideIntAttr && strideIntAttr.value() == 0) {
110110
strides.push_back(b.getIndexAttr(accumulate));
111+
} else if (auto v = llvm::dyn_cast_if_present<Value>(stride)) {
112+
OpFoldResult result = getAsOpFoldResult(v);
113+
strides.push_back(result);
111114
} else {
112115
strides.push_back(stride);
113116
}

test/Conversion/StructuredToMemref/addptr_2d_example.mlir

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,23 +48,22 @@ module {
4848
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)>
4949
// CHECK-LABEL: func.func @kernel
5050
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
51-
// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index
5251
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
5352
// CHECK-NOT: separator of consecutive DAGs
54-
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
53+
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
5554
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x256xbf16>
56-
// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16>
55+
// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16>
5756
// CHECK-DAG: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x256xbf16>
58-
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
57+
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
5958
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16>
60-
// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16>
59+
// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16>
6160
// CHECK: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16>
6261
// CHECK: [[VAR_3_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_1_]], [[VAR_2_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_1_]] : tensor<4x256xbf16>) {
6362
// CHECK: ^bb0([[IN_0_:%.+]]: bf16, [[IN_1_:%.+]]: bf16, [[IN_2_:%.+]]: bf16):
6463
// CHECK: [[VAR_4_:%.+]] = arith.addf [[IN_0_]], [[IN_1_]] : bf16
6564
// CHECK: linalg.yield [[VAR_4_]] : bf16
6665
// CHECK: } -> tensor<4x256xbf16>
67-
// CHECK: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
68-
// CHECK: bufferization.materialize_in_destination [[VAR_3_]] in writable [[VAR_reinterpret_cast_2_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ?], offset: ?>>) -> ()
66+
// CHECK: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
67+
// CHECK: bufferization.materialize_in_destination [[VAR_3_]] in writable [[VAR_reinterpret_cast_2_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, 5], offset: ?>>) -> ()
6968
// CHECK: return
7069
// CHECK: }

test/Conversion/StructuredToMemref/addptr_add_value.mlir

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,16 @@ module {
4949

5050
// CHECK-LABEL: func.func @kernel
5151
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
52-
// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index
5352
// CHECK-DAG: [[CST_10_:%.+]] = arith.constant 10 : index
5453
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
5554
// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
5655
// CHECK: [[VAR_2_:%.+]] = arith.addi [[VAR_0_]], [[VAR_1_]] : index
5756
// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_2_]], [[CST_10_]] : index
58-
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_6_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
59-
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_6_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
57+
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4, 256], strides: [1, 6] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 6], offset: ?>>
58+
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4, 256], strides: [1, 6] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 6], offset: ?>>
6059
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x256xbf16>
61-
// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16>
60+
// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, 6], offset: ?>> to memref<4x256xbf16>
6261
// CHECK: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x256xbf16>
63-
// CHECK: bufferization.materialize_in_destination [[VAR_4_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ?], offset: ?>>) -> ()
62+
// CHECK: bufferization.materialize_in_destination [[VAR_4_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, 6], offset: ?>>) -> ()
6463
// CHECK: return
6564
// CHECK: }

test/Conversion/StructuredToMemref/addptr_dim1.mlir

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ module {
6969
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
7070
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
7171
// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32
72-
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
7372
// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16
7473
// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : index
7574
// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4x256xbf16>
@@ -86,9 +85,9 @@ module {
8685
// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_arg8_]] : index to i32
8786
// CHECK: [[VAR_6_:%.+]] = arith.muli [[VAR_5_]], [[CST_256_]] : i32
8887
// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : i32 to index
89-
// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg10_]]{{.}}, sizes: [4, 256], strides: {{.}}[[VAR_7_]], [[CST_1_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>>
88+
// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg10_]]{{.}}, sizes: [4, 256], strides: {{.}}[[VAR_7_]], 1{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, 1], offset: ?>>
9089
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16>
91-
// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16>
90+
// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, 1], offset: ?>> to memref<4x256xbf16>
9291
// CHECK: [[VAR_8_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16>
9392
// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg9_]], [[VAR_8_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_arg9_]] : tensor<4x256xbf16>) {
9493
// CHECK: ^bb0([[IN_0_:%.+]]: bf16, [[IN_1_:%.+]]: bf16, [[IN_2_:%.+]]: bf16):
@@ -98,7 +97,7 @@ module {
9897
// CHECK: [[VAR_10_:%.+]] = arith.addi [[VAR_arg10_]], [[CST_256_1_]] : index
9998
// CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, index
10099
// CHECK: }
101-
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 256], strides: {{.}}[[CST_256_1_]], 1] : memref<*xbf16> to memref<4x256xbf16, strided<[?, 1]>>
102-
// CHECK: bufferization.materialize_in_destination [[VAR_4_]]#0 in writable [[VAR_reinterpret_cast_1_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, 1]>>) -> ()
100+
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 256], strides: [256, 1] : memref<*xbf16> to memref<4x256xbf16, strided<[256, 1]>>
101+
// CHECK: bufferization.materialize_in_destination [[VAR_4_]]#0 in writable [[VAR_reinterpret_cast_1_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[256, 1]>>) -> ()
103102
// CHECK: return
104103
// CHECK: }

test/Conversion/StructuredToMemref/addptr_for_accumulation.mlir

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,19 @@ module {
6262
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)>
6363
// CHECK-LABEL: func.func @kernel
6464
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32) {
65-
// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index
6665
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
6766
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
6867
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
69-
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
7068
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
7169
// CHECK-NOT: separator of consecutive DAGs
72-
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
70+
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
7371
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x256xbf16>
74-
// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16>
72+
// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16>
7573
// CHECK: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x256xbf16>
7674
// CHECK-DAG: [[VAR_2_:%.+]]:2 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg12_:%.+]] = [[VAR_1_]], [[VAR_arg13_:%.+]] = [[VAR_0_]]) -> (tensor<4x256xbf16>, index) {
77-
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [4, 256], strides: {{.}}[[CST_1_]], [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>>
75+
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
7876
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16>
79-
// CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16>
77+
// CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_1_]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16>
8078
// CHECK: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16>
8179
// CHECK: [[VAR_4_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg12_]], [[VAR_3_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_arg12_]] : tensor<4x256xbf16>) {
8280
// CHECK: ^bb0([[IN_0_:%.+]]: bf16, [[IN_1_:%.+]]: bf16, [[IN_2_:%.+]]: bf16):
@@ -86,7 +84,7 @@ module {
8684
// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_arg13_]], [[CST_3_]] : index
8785
// CHECK: scf.yield [[VAR_4_]], [[VAR_5_]] : tensor<4x256xbf16>, index
8886
// CHECK: }
89-
// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
90-
// CHECK: bufferization.materialize_in_destination [[VAR_2_]]#0 in writable [[VAR_reinterpret_cast_0_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ?], offset: ?>>) -> ()
87+
// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
88+
// CHECK: bufferization.materialize_in_destination [[VAR_2_]]#0 in writable [[VAR_reinterpret_cast_0_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, 5], offset: ?>>) -> ()
9189
// CHECK: return
9290
// CHECK: }

test/Conversion/StructuredToMemref/addptr_for_expand_ptr.mlir

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,15 @@ module {
5959
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
6060
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
6161
// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index
62-
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
6362
// CHECK-NOT: separator of consecutive DAGs
6463
// CHECK-DAG: [[VAR_0_:%.+]] = scf.for [[VAR_arg7_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg8_:%.+]] = [[CST_1024_]]) -> (index) {
6564
// CHECK-DAG: [[VAR_1_:%.+]] = arith.addi [[VAR_arg8_]], [[CST_256_]] : index
6665
// CHECK-NOT: separator of consecutive DAGs
67-
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [256, 256], strides: {{.}}[[CST_2_]], 1] : memref<*xbf16> to memref<256x256xbf16, strided<[?, 1], offset: ?>>
66+
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [256, 256], strides: [2, 1] : memref<*xbf16> to memref<256x256xbf16, strided<[2, 1], offset: ?>>
6867
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<256x256xbf16>
69-
// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<256x256xbf16, strided<[?, 1], offset: ?>> to memref<256x256xbf16>
68+
// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<256x256xbf16, strided<[2, 1], offset: ?>> to memref<256x256xbf16>
7069
// CHECK: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256x256xbf16>
71-
// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[VAR_reinterpret_cast_]] : (tensor<256x256xbf16>, memref<256x256xbf16, strided<[?, 1], offset: ?>>) -> ()
70+
// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[VAR_reinterpret_cast_]] : (tensor<256x256xbf16>, memref<256x256xbf16, strided<[2, 1], offset: ?>>) -> ()
7271
// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_arg8_]], [[CST_3_]] : index
7372
// CHECK: scf.yield [[VAR_3_]] : index
7473
// CHECK: }

test/Conversion/StructuredToMemref/addptr_for_more_init_args.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,12 @@ module {
5252
// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index
5353
// CHECK-NOT: separator of consecutive DAGs
5454
// CHECK-DAG: [[VAR_0_:%.+]]:5 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg9_:%.+]] = [[CST_1_]], [[VAR_arg10_:%.+]] = [[CST_1024_]], [[VAR_arg11_:%.+]] = [[CST_2_]], [[VAR_arg12_:%.+]] = [[CST_1024_]], [[VAR_arg13_:%.+]] = [[CST_3_]]) -> (index, index, index, index, index) {
55-
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg12_]]{{.}}, sizes: [256], strides: {{.}}[[CST_1_]]{{.}} : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
56-
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg10_]]{{.}}, sizes: [256], strides: {{.}}[[CST_1_]]{{.}} : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
55+
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg12_]]{{.}}, sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: ?>>
56+
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg10_]]{{.}}, sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: ?>>
5757
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<256xbf16>
58-
// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16>
58+
// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_]] : memref<256xbf16, strided<[1], offset: ?>> to memref<256xbf16>
5959
// CHECK: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256xbf16>
60-
// CHECK: bufferization.materialize_in_destination [[VAR_1_]] in writable [[VAR_reinterpret_cast_]] : (tensor<256xbf16>, memref<256xbf16, strided<[?], offset: ?>>) -> ()
60+
// CHECK: bufferization.materialize_in_destination [[VAR_1_]] in writable [[VAR_reinterpret_cast_]] : (tensor<256xbf16>, memref<256xbf16, strided<[1], offset: ?>>) -> ()
6161
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_arg10_]], [[CST_3_]] : index
6262
// CHECK-DAG: [[VAR_3_:%.+]] = arith.addi [[VAR_arg9_]], [[CST_3_]] : index
6363
// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_arg11_]], [[CST_3_]] : index

0 commit comments

Comments
 (0)