Skip to content

Commit 0c3e629

Browse files
authored
Update triton to cb78503 (#247)
1 parent a53ef0f commit 0c3e629

4 files changed

Lines changed: 14 additions & 400 deletions

File tree

python/examples/test_scalar_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def compile():
2424
src = triton.compiler.ASTSource(
2525
fn=test_scalar_store,
2626
signature="*fp32",
27-
constants={
27+
constexprs={
2828
"BLOCK_SIZE": 8
2929
}
3030
)

test/Conversion/StructuredToMemref/nested_loops.mlir

Lines changed: 6 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -179,173 +179,9 @@ module {
179179
}
180180
}
181181

182-
// CHECK-LABEL: func.func @nested2_complex_body
183-
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
184-
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
185-
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
186-
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
187-
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
188-
// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index
189-
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
190-
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
191-
// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
192-
// CHECK: [[VAR_2_:%.+]] = arith.muli [[PARAM_2_]], [[CST_2_]] : i32
193-
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
194-
// CHECK-DAG: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg10_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg11_:%.+]] = [[CST_0_1_]], [[VAR_arg12_:%.+]] = [[CST_0_1_]]) -> (index, index) : i32 {
195-
// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_arg11_]], [[CST_1_1_]] : index
196-
// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[VAR_arg12_]], [[CST_1_1_]] : index
197-
// CHECK-NOT: separator of consecutive DAGs
198-
// CHECK-DAG: [[VAR_7_:%.+]]:2 = scf.for [[VAR_arg13_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg14_:%.+]] = [[VAR_5_]], [[VAR_arg15_:%.+]] = [[VAR_6_]]) -> (index, index) : i32 {
199-
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg15_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
200-
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg14_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
201-
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<2x2xf32>
202-
// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_]] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32>
203-
// CHECK: [[VAR_12_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<2x2xf32>
204-
// CHECK: bufferization.materialize_in_destination [[VAR_12_]] in writable [[VAR_reinterpret_cast_]] : (tensor<2x2xf32>, memref<2x2xf32, strided<[?, ?], offset: ?>>) -> ()
205-
// CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_arg14_]], [[CST_3_]] : index
206-
// CHECK-DAG: [[VAR_14_:%.+]] = arith.addi [[VAR_arg15_]], [[CST_3_]] : index
207-
// CHECK: scf.yield [[VAR_13_]], [[VAR_14_]] : index, index
208-
// CHECK: }
209-
// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_arg11_]], [[VAR_3_]] : index
210-
// CHECK-DAG: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[CST_1_1_]] : index
211-
// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[VAR_arg12_]], [[VAR_3_]] : index
212-
// CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_10_]], [[CST_1_1_]] : index
213-
// CHECK: scf.yield [[VAR_9_]], [[VAR_11_]] : index, index
214-
// CHECK: }
215-
// CHECK: return
216-
// CHECK: }
217-
//
218-
// CHECK-LABEL: func.func @nested2_use_loop_results
219-
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
220-
// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32
221-
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
222-
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
223-
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
224-
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
225-
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
226-
// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
227-
// CHECK: [[VAR_2_:%.+]] = arith.muli [[PARAM_3_]], [[CST_4_]] : i32
228-
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
229-
// CHECK-DAG: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg10_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg11_:%.+]] = [[CST_0_1_]], [[VAR_arg12_:%.+]] = [[CST_0_1_]]) -> (index, index) : i32 {
230-
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg12_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
231-
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg11_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
232-
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<2x2xf32>
233-
// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_]] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32>
234-
// CHECK: [[VAR_5_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<2x2xf32>
235-
// CHECK: bufferization.materialize_in_destination [[VAR_5_]] in writable [[VAR_reinterpret_cast_]] : (tensor<2x2xf32>, memref<2x2xf32, strided<[?, ?], offset: ?>>) -> ()
236-
// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[VAR_arg11_]], [[VAR_3_]] : index
237-
// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_arg12_]], [[VAR_3_]] : index
238-
// CHECK-NOT: separator of consecutive DAGs
239-
// CHECK-DAG: [[VAR_8_:%.+]]:2 = scf.for [[VAR_arg13_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg14_:%.+]] = [[VAR_6_]], [[VAR_arg15_:%.+]] = [[VAR_7_]]) -> (index, index) : i32 {
240-
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg15_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
241-
// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg14_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
242-
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<2x2xf32>
243-
// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_1_]] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32>
244-
// CHECK: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<2x2xf32>
245-
// CHECK: bufferization.materialize_in_destination [[VAR_9_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<2x2xf32>, memref<2x2xf32, strided<[?, ?], offset: ?>>) -> ()
246-
// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[VAR_arg14_]], [[VAR_3_]] : index
247-
// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_3_]] : index
248-
// CHECK: scf.yield [[VAR_10_]], [[VAR_11_]] : index, index
249-
// CHECK: }
250-
// CHECK: scf.yield [[VAR_8_]]#0, [[VAR_8_]]#1 : index, index
251-
// CHECK: }
252-
// CHECK: return
253-
// CHECK: }
254-
//
255-
// CHECK-LABEL: func.func @nested3
256-
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
257-
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
258-
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
259-
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
260-
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
261-
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
262-
// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
263-
// CHECK-NOT: separator of consecutive DAGs
264-
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
265-
// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_]] : i32
266-
// CHECK-NOT: separator of consecutive DAGs
267-
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
268-
// CHECK-DAG: [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg10_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg11_:%.+]] = [[CST_0_1_]], [[VAR_arg12_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg13_:%.+]] = [[CST_0_1_]]) -> (index, memref<2x2xf32, strided<[?, ?], offset: ?>>, index) : i32 {
269-
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg11_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
270-
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<2x2xf32>
271-
// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_]] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32>
272-
// CHECK-DAG: [[VAR_5_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<2x2xf32>
273-
// CHECK-DAG: [[VAR_6_:%.+]]:3 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_arg11_]], [[VAR_arg16_:%.+]] = [[VAR_arg12_]], [[VAR_arg17_:%.+]] = [[VAR_arg13_]]) -> (index, memref<2x2xf32, strided<[?, ?], offset: ?>>, index) : i32 {
274-
// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_3_]] : index
275-
// CHECK-NOT: separator of consecutive DAGs
276-
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
277-
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<2x2xf32>
278-
// CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_1_]] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32>
279-
// CHECK-DAG: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<2x2xf32>
280-
// CHECK-DAG: [[VAR_10_:%.+]]:3 = scf.for [[VAR_arg18_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg19_:%.+]] = [[VAR_8_]], [[VAR_arg20_:%.+]] = [[VAR_arg16_]], [[VAR_arg21_:%.+]] = [[VAR_arg17_]]) -> (index, memref<2x2xf32, strided<[?, ?], offset: ?>>, index) : i32 {
281-
// CHECK-DAG: [[VAR_reinterpret_cast_3_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg21_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
282-
// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_arg19_]], [[VAR_3_]] : index
283-
// CHECK-NOT: separator of consecutive DAGs
284-
// CHECK-DAG: [[VAR_reinterpret_cast_4_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
285-
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<2x2xf32>
286-
// CHECK: memref.copy [[VAR_reinterpret_cast_4_]], [[RES_2_]] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32>
287-
// CHECK: [[VAR_12_:%.+]] = bufferization.to_tensor [[RES_2_]] restrict writable : memref<2x2xf32>
288-
// CHECK: bufferization.materialize_in_destination [[VAR_5_]] in writable [[VAR_reinterpret_cast_3_]] : (tensor<2x2xf32>, memref<2x2xf32, strided<[?, ?], offset: ?>>) -> ()
289-
// CHECK: [[VAR_13_:%.+]] = arith.addi [[VAR_arg21_]], [[VAR_3_]] : index
290-
// CHECK: [[VAR_reinterpret_cast_6_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_13_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
291-
// CHECK: bufferization.materialize_in_destination [[VAR_9_]] in writable [[VAR_reinterpret_cast_6_]] : (tensor<2x2xf32>, memref<2x2xf32, strided<[?, ?], offset: ?>>) -> ()
292-
// CHECK: [[VAR_14_:%.+]] = arith.addi [[VAR_13_]], [[VAR_3_]] : index
293-
// CHECK: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_14_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
294-
// CHECK: bufferization.materialize_in_destination [[VAR_12_]] in writable [[VAR_reinterpret_cast_7_]] : (tensor<2x2xf32>, memref<2x2xf32, strided<[?, ?], offset: ?>>) -> ()
295-
// CHECK: [[VAR_15_:%.+]] = arith.addi [[VAR_14_]], [[VAR_3_]] : index
296-
// CHECK: [[VAR_reinterpret_cast_8_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
297-
// CHECK: scf.yield [[VAR_11_]], [[VAR_reinterpret_cast_8_]], [[VAR_15_]] : index, memref<2x2xf32, strided<[?, ?], offset: ?>>, index
298-
// CHECK: }
299-
// CHECK: scf.yield [[VAR_10_]]#0, [[VAR_10_]]#1, [[VAR_10_]]#2 : index, memref<2x2xf32, strided<[?, ?], offset: ?>>, index
300-
// CHECK: }
301-
// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_6_]]#0, [[VAR_3_]] : index
302-
// CHECK: scf.yield [[VAR_7_]], [[VAR_6_]]#1, [[VAR_6_]]#2 : index, memref<2x2xf32, strided<[?, ?], offset: ?>>, index
303-
// CHECK: }
304-
// CHECK: return
305-
// CHECK: }
306-
//
307-
// CHECK-LABEL: func.func @nested_use_same_level_loop_result
308-
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
309-
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
310-
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
311-
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
312-
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
313-
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
314-
// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
315-
// CHECK-NOT: separator of consecutive DAGs
316-
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
317-
// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_]] : i32
318-
// CHECK-NOT: separator of consecutive DAGs
319-
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
320-
// CHECK-DAG: [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg10_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg11_:%.+]] = [[CST_0_1_]], [[VAR_arg12_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg13_:%.+]] = [[CST_0_1_]]) -> (index, memref<2x2xf32, strided<[?, ?], offset: ?>>, index) : i32 {
321-
// CHECK-DAG: [[VAR_5_:%.+]] = scf.for [[VAR_arg14_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_arg11_]]) -> (index) : i32 {
322-
// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_3_]] : index
323-
// CHECK: scf.yield [[VAR_8_]] : index
324-
// CHECK: }
325-
// CHECK-DAG: [[VAR_6_:%.+]]:3 = scf.for [[VAR_arg14_1_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg15_1_:%.+]] = [[VAR_5_]], [[VAR_arg16_:%.+]] = [[VAR_arg12_]], [[VAR_arg17_:%.+]] = [[VAR_arg13_]]) -> (index, memref<2x2xf32, strided<[?, ?], offset: ?>>, index) : i32 {
326-
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg17_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
327-
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg15_1_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
328-
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<2x2xf32>
329-
// CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_]] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32>
330-
// CHECK-DAG: [[VAR_8_1_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<2x2xf32>
331-
// CHECK-DAG: [[VAR_9_:%.+]] = arith.addi [[VAR_arg15_1_]], [[VAR_3_]] : index
332-
// CHECK-NOT: separator of consecutive DAGs
333-
// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_9_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
334-
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<2x2xf32>
335-
// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_1_]] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32>
336-
// CHECK: [[VAR_10_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<2x2xf32>
337-
// CHECK: bufferization.materialize_in_destination [[VAR_8_1_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<2x2xf32>, memref<2x2xf32, strided<[?, ?], offset: ?>>) -> ()
338-
// CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_arg17_]], [[VAR_3_]] : index
339-
// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_11_]], [[VAR_3_]] : index
340-
// CHECK: [[VAR_reinterpret_cast_4_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
341-
// CHECK: bufferization.materialize_in_destination [[VAR_10_]] in writable [[VAR_reinterpret_cast_4_]] : (tensor<2x2xf32>, memref<2x2xf32, strided<[?, ?], offset: ?>>) -> ()
342-
// CHECK: [[VAR_13_:%.+]] = arith.addi [[VAR_12_]], [[VAR_3_]] : index
343-
// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_13_]]{{.}}, sizes: [2, 2], strides: {{.}}[[VAR_0_]], [[VAR_1_]]{{.}} : memref<*xf32> to memref<2x2xf32, strided<[?, ?], offset: ?>>
344-
// CHECK-DAG: [[VAR_14_:%.+]] = arith.addi [[VAR_9_]], [[VAR_3_]] : index
345-
// CHECK: scf.yield [[VAR_14_]], [[VAR_reinterpret_cast_5_]], [[VAR_13_]] : index, memref<2x2xf32, strided<[?, ?], offset: ?>>, index
346-
// CHECK: }
347-
// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_6_]]#0, [[VAR_3_]] : index
348-
// CHECK: scf.yield [[VAR_7_]], [[VAR_6_]]#1, [[VAR_6_]]#2 : index, memref<2x2xf32, strided<[?, ?], offset: ?>>, index
349-
// CHECK: }
350-
// CHECK: return
351-
// CHECK: }
182+
// CHECK-NOT: tt.addptr
183+
// CHECK-NOT: tt.load
184+
// CHECK-NOT: tt.store
185+
186+
// CHECK-COUNT-20: memref.reinterpret_cast %arg{{[0-9]+}}
187+
// CHECK-NOT: memref.reinterpret_cast %arg{{[0-9]+}}

0 commit comments

Comments
 (0)