Commit a53ef0f

Set stride info for output type when lowering make_blk_ptr from TTIR to MLIR (#245)
When a tensor is described with make_block_ptr in Triton, the TTIR to MLIR lowering produces a memref.reinterpret_cast whose result type has dynamic strides. The shape and access information are known at compile time, yet the type still carries dynamic stride information.

```
# Triton code
in_block_ptr = tl.make_block_ptr(
    base=x_ptr,
    shape=(B, H, W, D),
    strides=(H * W * D, W * D, D, 1),
    offsets=(0, 0, 0, 0),
    block_shape=(B, H, W, D),
    order=(3, 2, 1, 0),
)
```

```
// MLIR code
%reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [4, 2, 2, 128], strides: [%c512, %c256, %c128, %c1] : memref<*xf16> to memref<4x2x2x128xf16, strided<[?, ?, ?, ?], offset: ?>>
```

If the initial MLIR contains dynamic stride information in the result type of memref.reinterpret_cast, the rest of the lowering pipeline generates code as though the strides were unknown. Consequently, the resulting code gets scalarized at some stage of the pipeline.

This PR sets the stride information when converting TTIR to MLIR, so that static strides are already available in the initial MLIR. This enables vector code generation when using make_block_ptr.
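The strides passed to make_block_ptr above are the contiguous (row-major) strides of a (B, H, W, D) tensor. As a quick sanity check, the following minimal sketch recomputes them, assuming B=4, H=2, W=2, D=128 (the sizes that appear in the MLIR snippet). `row_major_strides` is a hypothetical helper written for this illustration, not part of Triton.

```python
def row_major_strides(shape):
    """Return contiguous (row-major) strides, in elements, for `shape`."""
    strides = []
    running = 1
    # Walk from the innermost dimension outwards, accumulating the extent.
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


# Assumed sizes, taken from `sizes: [4, 2, 2, 128]` in the MLIR snippet.
B, H, W, D = 4, 2, 2, 128
strides = row_major_strides((B, H, W, D))
print(strides)
```

Under these assumptions the result equals (H * W * D, W * D, D, 1), which corresponds to the constants %c512, %c256, %c128, %c1 in the reinterpret_cast. Every stride is therefore a compile-time constant, which is why a fully dynamic `strided<[?, ?, ?, ?], ...>` type loses information.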
1 parent 19fb6f3 commit a53ef0f

26 files changed

Lines changed: 104 additions & 131 deletions

lib/Conversion/StructuredToMemref/StructuredToMemref.cpp

Lines changed: 3 additions & 0 deletions
@@ -108,6 +108,9 @@ struct MakeTensorPtrConverter
       auto strideIntAttr = getIntAttr(stride);
       if (size == 1 && strideIntAttr && strideIntAttr.value() == 0) {
         strides.push_back(b.getIndexAttr(accumulate));
+      } else if (auto v = llvm::dyn_cast_if_present<Value>(stride)) {
+        OpFoldResult result = getAsOpFoldResult(v);
+        strides.push_back(result);
       } else {
         strides.push_back(stride);
       }
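The effect of the new branch can be illustrated with a toy model in Python. This is only a sketch under stated assumptions: the `Value` class, `fold_stride`, and `strided_layout` below are hypothetical stand-ins, not MLIR APIs. The model assumes that getAsOpFoldResult turns a value defined by a constant into that constant and leaves any other value untouched, which is the behavior the updated test expectations suggest.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Value:
    """Toy stand-in for an SSA value; `constant` is set when a constant op defines it."""
    name: str
    constant: Optional[int] = None


def fold_stride(stride):
    """Replace a constant-defined value by its integer; keep anything else as-is."""
    if isinstance(stride, Value) and stride.constant is not None:
        return stride.constant
    return stride


def strided_layout(strides):
    """Render the layout as the memref type prints it: '?' marks a dynamic stride."""
    parts = [str(s) if isinstance(s, int) else "?" for s in strides]
    return "strided<[" + ", ".join(parts) + "], offset: ?>"


c5 = Value("%c5", constant=5)    # stride defined by `arith.constant 5 : index`
c1 = Value("%c1", constant=1)    # stride defined by `arith.constant 1 : index`
dyn = Value("%7")                # stride computed at run time

before = strided_layout([1, c5])
after = strided_layout([fold_stride(s) for s in [1, c5]])
mixed = strided_layout([fold_stride(s) for s in [dyn, c1]])
print(before, after, mixed, sep="\n")
```

In this model, a stride that is truly computed at run time stays dynamic, while constant strides become part of the type. That mirrors the test updates in this commit, where `strided<[1, ?], offset: ?>` becomes `strided<[1, 5], offset: ?>` and `strided<[?, ?], offset: ?>` becomes `strided<[?, 1], offset: ?>`.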

test/Conversion/StructuredToMemref/addptr_2d_example.mlir

Lines changed: 6 additions & 7 deletions
@@ -48,23 +48,22 @@ module {
 // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @kernel
 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
-// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index
 // CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x256xbf16>
-// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16>
+// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16>
 // CHECK-DAG: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x256xbf16>
-// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
 // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16>
-// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16>
+// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16>
 // CHECK: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16>
 // CHECK: [[VAR_3_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_1_]], [[VAR_2_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_1_]] : tensor<4x256xbf16>) {
 // CHECK: ^bb0([[IN_0_:%.+]]: bf16, [[IN_1_:%.+]]: bf16, [[IN_2_:%.+]]: bf16):
 // CHECK: [[VAR_4_:%.+]] = arith.addf [[IN_0_]], [[IN_1_]] : bf16
 // CHECK: linalg.yield [[VAR_4_]] : bf16
 // CHECK: } -> tensor<4x256xbf16>
-// CHECK: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_3_]] in writable [[VAR_reinterpret_cast_2_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ?], offset: ?>>) -> ()
+// CHECK: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
+// CHECK: bufferization.materialize_in_destination [[VAR_3_]] in writable [[VAR_reinterpret_cast_2_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, 5], offset: ?>>) -> ()
 // CHECK: return
 // CHECK: }

test/Conversion/StructuredToMemref/addptr_add_value.mlir

Lines changed: 4 additions & 5 deletions
@@ -49,17 +49,16 @@ module {
 
 // CHECK-LABEL: func.func @kernel
 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
-// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index
 // CHECK-DAG: [[CST_10_:%.+]] = arith.constant 10 : index
 // CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
 // CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
 // CHECK: [[VAR_2_:%.+]] = arith.addi [[VAR_0_]], [[VAR_1_]] : index
 // CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_2_]], [[CST_10_]] : index
-// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_6_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
-// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_6_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4, 256], strides: [1, 6] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 6], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4, 256], strides: [1, 6] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 6], offset: ?>>
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x256xbf16>
-// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16>
+// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, 6], offset: ?>> to memref<4x256xbf16>
 // CHECK: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x256xbf16>
-// CHECK: bufferization.materialize_in_destination [[VAR_4_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ?], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_4_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, 6], offset: ?>>) -> ()
 // CHECK: return
 // CHECK: }

test/Conversion/StructuredToMemref/addptr_dim1.mlir

Lines changed: 4 additions & 5 deletions
@@ -69,7 +69,6 @@ module {
 // CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
 // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32
-// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
 // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16
 // CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : index
 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4x256xbf16>
@@ -86,9 +85,9 @@ module {
 // CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_arg8_]] : index to i32
 // CHECK: [[VAR_6_:%.+]] = arith.muli [[VAR_5_]], [[CST_256_]] : i32
 // CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : i32 to index
-// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg10_]]{{.}}, sizes: [4, 256], strides: {{.}}[[VAR_7_]], [[CST_1_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg10_]]{{.}}, sizes: [4, 256], strides: {{.}}[[VAR_7_]], 1{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, 1], offset: ?>>
 // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16>
-// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16>
+// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, 1], offset: ?>> to memref<4x256xbf16>
 // CHECK: [[VAR_8_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16>
 // CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg9_]], [[VAR_8_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_arg9_]] : tensor<4x256xbf16>) {
 // CHECK: ^bb0([[IN_0_:%.+]]: bf16, [[IN_1_:%.+]]: bf16, [[IN_2_:%.+]]: bf16):
@@ -98,7 +97,7 @@ module {
 // CHECK: [[VAR_10_:%.+]] = arith.addi [[VAR_arg10_]], [[CST_256_1_]] : index
 // CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, index
 // CHECK: }
-// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 256], strides: {{.}}[[CST_256_1_]], 1] : memref<*xbf16> to memref<4x256xbf16, strided<[?, 1]>>
-// CHECK: bufferization.materialize_in_destination [[VAR_4_]]#0 in writable [[VAR_reinterpret_cast_1_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, 1]>>) -> ()
+// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 256], strides: [256, 1] : memref<*xbf16> to memref<4x256xbf16, strided<[256, 1]>>
+// CHECK: bufferization.materialize_in_destination [[VAR_4_]]#0 in writable [[VAR_reinterpret_cast_1_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[256, 1]>>) -> ()
 // CHECK: return
 // CHECK: }

test/Conversion/StructuredToMemref/addptr_for_accumulation.mlir

Lines changed: 6 additions & 8 deletions
@@ -62,21 +62,19 @@ module {
 // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @kernel
 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32) {
-// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index
 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
 // CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
 // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
-// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
 // CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x256xbf16>
-// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16>
+// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16>
 // CHECK: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x256xbf16>
 // CHECK-DAG: [[VAR_2_:%.+]]:2 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg12_:%.+]] = [[VAR_1_]], [[VAR_arg13_:%.+]] = [[VAR_0_]]) -> (tensor<4x256xbf16>, index) {
-// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [4, 256], strides: {{.}}[[CST_1_]], [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
 // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16>
-// CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16>
+// CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_1_]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16>
 // CHECK: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16>
 // CHECK: [[VAR_4_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg12_]], [[VAR_3_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_arg12_]] : tensor<4x256xbf16>) {
 // CHECK: ^bb0([[IN_0_:%.+]]: bf16, [[IN_1_:%.+]]: bf16, [[IN_2_:%.+]]: bf16):
@@ -86,7 +84,7 @@ module {
 // CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_arg13_]], [[CST_3_]] : index
 // CHECK: scf.yield [[VAR_4_]], [[VAR_5_]] : tensor<4x256xbf16>, index
 // CHECK: }
-// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_2_]]#0 in writable [[VAR_reinterpret_cast_0_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ?], offset: ?>>) -> ()
+// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5], offset: ?>>
+// CHECK: bufferization.materialize_in_destination [[VAR_2_]]#0 in writable [[VAR_reinterpret_cast_0_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, 5], offset: ?>>) -> ()
 // CHECK: return
 // CHECK: }

test/Conversion/StructuredToMemref/addptr_for_expand_ptr.mlir

Lines changed: 3 additions & 4 deletions
@@ -59,16 +59,15 @@ module {
 // CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
 // CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index
-// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_0_:%.+]] = scf.for [[VAR_arg7_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg8_:%.+]] = [[CST_1024_]]) -> (index) {
 // CHECK-DAG: [[VAR_1_:%.+]] = arith.addi [[VAR_arg8_]], [[CST_256_]] : index
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [256, 256], strides: {{.}}[[CST_2_]], 1] : memref<*xbf16> to memref<256x256xbf16, strided<[?, 1], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [256, 256], strides: [2, 1] : memref<*xbf16> to memref<256x256xbf16, strided<[2, 1], offset: ?>>
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<256x256xbf16>
-// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<256x256xbf16, strided<[?, 1], offset: ?>> to memref<256x256xbf16>
+// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<256x256xbf16, strided<[2, 1], offset: ?>> to memref<256x256xbf16>
 // CHECK: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256x256xbf16>
-// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[VAR_reinterpret_cast_]] : (tensor<256x256xbf16>, memref<256x256xbf16, strided<[?, 1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[VAR_reinterpret_cast_]] : (tensor<256x256xbf16>, memref<256x256xbf16, strided<[2, 1], offset: ?>>) -> ()
 // CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_arg8_]], [[CST_3_]] : index
 // CHECK: scf.yield [[VAR_3_]] : index
 // CHECK: }

test/Conversion/StructuredToMemref/addptr_for_more_init_args.mlir

Lines changed: 4 additions & 4 deletions
@@ -52,12 +52,12 @@ module {
 // CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_0_:%.+]]:5 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg9_:%.+]] = [[CST_1_]], [[VAR_arg10_:%.+]] = [[CST_1024_]], [[VAR_arg11_:%.+]] = [[CST_2_]], [[VAR_arg12_:%.+]] = [[CST_1024_]], [[VAR_arg13_:%.+]] = [[CST_3_]]) -> (index, index, index, index, index) {
-// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg12_]]{{.}}, sizes: [256], strides: {{.}}[[CST_1_]]{{.}} : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
-// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg10_]]{{.}}, sizes: [256], strides: {{.}}[[CST_1_]]{{.}} : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg12_]]{{.}}, sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg10_]]{{.}}, sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1], offset: ?>>
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<256xbf16>
-// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16>
+// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_]] : memref<256xbf16, strided<[1], offset: ?>> to memref<256xbf16>
 // CHECK: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256xbf16>
-// CHECK: bufferization.materialize_in_destination [[VAR_1_]] in writable [[VAR_reinterpret_cast_]] : (tensor<256xbf16>, memref<256xbf16, strided<[?], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_1_]] in writable [[VAR_reinterpret_cast_]] : (tensor<256xbf16>, memref<256xbf16, strided<[1], offset: ?>>) -> ()
 // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_arg10_]], [[CST_3_]] : index
 // CHECK-DAG: [[VAR_3_:%.+]] = arith.addi [[VAR_arg9_]], [[CST_3_]] : index
 // CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_arg11_]], [[CST_3_]] : index
