@@ -62,21 +62,19 @@ module {
6262// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)>
6363// CHECK-LABEL: func.func @kernel
6464// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32) {
65- // CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index
6665// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
6766// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
6867// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
69- // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
7068// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
7169// CHECK-NOT: separator of consecutive DAGs
72- // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ? ], offset: ?>>
70+ // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5 ], offset: ?>>
7371// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x256xbf16>
74- // CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, ? ], offset: ?>> to memref<4x256xbf16>
72+ // CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x256xbf16, strided<[1, 5 ], offset: ?>> to memref<4x256xbf16>
7573// CHECK: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x256xbf16>
7674// CHECK-DAG: [[VAR_2_:%.+]]:2 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg12_:%.+]] = [[VAR_1_]], [[VAR_arg13_:%.+]] = [[VAR_0_]]) -> (tensor<4x256xbf16>, index) {
77- // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [4, 256], strides: {{.}}[[CST_1_]], [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, ? ], offset: ?>>
75+ // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5 ], offset: ?>>
7876// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16>
79- // CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, ? ], offset: ?>> to memref<4x256xbf16>
77+ // CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_1_]] : memref<4x256xbf16, strided<[1, 5 ], offset: ?>> to memref<4x256xbf16>
8078// CHECK: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16>
8179// CHECK: [[VAR_4_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg12_]], [[VAR_3_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_arg12_]] : tensor<4x256xbf16>) {
8280// CHECK: ^bb0([[IN_0_:%.+]]: bf16, [[IN_1_:%.+]]: bf16, [[IN_2_:%.+]]: bf16):
@@ -86,7 +84,7 @@ module {
8684// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_arg13_]], [[CST_3_]] : index
8785// CHECK: scf.yield [[VAR_4_]], [[VAR_5_]] : tensor<4x256xbf16>, index
8886// CHECK: }
89- // CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, [[CST_5_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[1, ? ], offset: ?>>
90- // CHECK: bufferization.materialize_in_destination [[VAR_2_]]#0 in writable [[VAR_reinterpret_cast_0_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, ? ], offset: ?>>) -> ()
87+ // CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_0_]]{{.}}, sizes: [4, 256], strides: [1, 5] : memref<*xbf16> to memref<4x256xbf16, strided<[1, 5 ], offset: ?>>
88+ // CHECK: bufferization.materialize_in_destination [[VAR_2_]]#0 in writable [[VAR_reinterpret_cast_0_]] : (tensor<4x256xbf16>, memref<4x256xbf16, strided<[1, 5 ], offset: ?>>) -> ()
9189// CHECK: return
9290// CHECK: }
0 commit comments