Skip to content

Commit c3f2342

Browse files
authored
Fix incorrect shape analysis for wraparound stacked load (#355)
If clampedOff does not overflow, then d1 may become greater than rowSize. In such cases, d1 should be restricted with `d1 = min(d1, rowSize)`. Additionally, redundant minsi and subi operations in the subview have been simplified by using the block shape directly.
1 parent 2b728ad commit c3f2342

5 files changed

Lines changed: 260 additions & 31 deletions

File tree

lib/Conversion/StructuredToMemref/StructuredToMemref.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,30 @@ struct MakeTensorPtrConverter
409409
// clampedOff - targetOffset
410410
// d1 = --------------------
411411
// strideRows
412+
//
413+
////////////////////////////////////////////////////////////////////////////
414+
//
415+
// cols
416+
//
417+
// wrappedAroundOff
418+
// --------------*---------------------
419+
// | |
420+
// | targetOffset |
421+
// | *------------| |
422+
// | | | |
423+
// | | | |
424+
// rows| rowSize | | |
425+
// | | | |
426+
// | | | |
427+
// | *------------| |
428+
// | nextOff |
429+
// | |
430+
// | clampedOff |
431+
// --------------*---------------------
432+
//
433+
// For the case where clampedOff does not overflow
434+
// d1 = min(d1, rowSize)
435+
//
412436

413437
auto resultType = getResultMemrefType(
414438
op, /* offset */ ShapedType::kDynamic,
@@ -443,6 +467,7 @@ struct MakeTensorPtrConverter
443467
rewriter.create<arith::AddIOp>(loc, modRow, wrappedAroundOff);
444468
Value d1 = rewriter.create<arith::SubIOp>(loc, clampedOff, targetOffset);
445469
d1 = rewriter.create<arith::DivSIOp>(loc, d1, strideRow);
470+
d1 = rewriter.create<arith::MinSIOp>(loc, d1, rowSize);
446471

447472
SmallVector<Value> sizes1{d1, colSize};
448473
memref::ReinterpretCastOp cast1 =
@@ -685,11 +710,10 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
685710
ConversionPatternRewriter &rewriter) const {
686711
OpFoldResult subviewRowFull = dims[0];
687712
OpFoldResult subviewColFull = dims[1];
688-
OpFoldResult col1 =
713+
OpFoldResult subviewCol1 =
689714
rewriter.create<memref::DimOp>(loc, block1, 1).getResult();
690-
OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, rewriter);
691715
OpFoldResult subviewCol2 =
692-
subOFRs(subviewColFull, subviewCol1, loc, rewriter);
716+
rewriter.create<memref::DimOp>(loc, block2, 1).getResult();
693717

694718
SmallVector<OpFoldResult> offsets(dims.size(), rewriter.getIndexAttr(0));
695719
SmallVector<OpFoldResult> strides(dims.size(), rewriter.getIndexAttr(1));
@@ -707,11 +731,10 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
707731
ConversionPatternRewriter &rewriter) const {
708732
OpFoldResult subviewRowFull = dims[0];
709733
OpFoldResult subviewColFull = dims[1];
710-
OpFoldResult row1 =
734+
OpFoldResult subviewRow1 =
711735
rewriter.create<memref::DimOp>(loc, block1, 0).getResult();
712-
OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, rewriter);
713736
OpFoldResult subviewRow2 =
714-
subOFRs(subviewRowFull, subviewRow1, loc, rewriter);
737+
rewriter.create<memref::DimOp>(loc, block2, 0).getResult();
715738

716739
SmallVector<OpFoldResult> offsets(dims.size(), rewriter.getIndexAttr(0));
717740
SmallVector<OpFoldResult> strides(dims.size(), rewriter.getIndexAttr(1));

python/examples/test_mm.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
2+
import torch
3+
import triton
4+
import pytest
5+
import triton.language as tl
6+
import benchmark
7+
8+
@triton.jit
9+
def prev_multiple_of(a, b):
10+
# the largest x < a such that x % b == 0
11+
return tl.cdiv(a, b) * b - b
12+
13+
@triton.jit
14+
def mm_kernel(
15+
A,
16+
B,
17+
C,
18+
M,
19+
N,
20+
K,
21+
stride_am,
22+
stride_ak,
23+
stride_bk,
24+
stride_bn,
25+
stride_cm,
26+
stride_cn,
27+
BLOCK_M: tl.constexpr,
28+
BLOCK_N: tl.constexpr,
29+
BLOCK_K: tl.constexpr,
30+
GROUP_M: tl.constexpr,
31+
):
32+
# matrix multiplication
33+
pid = tl.program_id(0)
34+
grid_m = tl.cdiv(M, BLOCK_M)
35+
grid_n = tl.cdiv(N, BLOCK_N)
36+
# re-order program ID for better L2 performance
37+
width = GROUP_M * grid_n
38+
group_id = pid // width
39+
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
40+
pid_m = group_id * GROUP_M + (pid % group_size)
41+
pid_n = (pid % width) // (group_size)
42+
# do matrix multiplication
43+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
44+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
45+
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
46+
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
47+
prev_multiple = prev_multiple_of(K, BLOCK_K)
48+
49+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
50+
for start_k in range(0, prev_multiple, BLOCK_K):
51+
rk = start_k + tl.arange(0, BLOCK_K)
52+
a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
53+
b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
54+
if a.dtype != b.dtype:
55+
a = a.to(C.dtype.element_ty)
56+
b = b.to(C.dtype.element_ty)
57+
acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
58+
59+
# loop peeling
60+
rk = prev_multiple + tl.arange(0, BLOCK_K)
61+
mask_k = rk < K
62+
a = tl.load(
63+
A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
64+
mask=mask_k[None, :],
65+
other=0.0
66+
)
67+
b = tl.load(
68+
B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
69+
mask=mask_k[:, None],
70+
other=0.0
71+
)
72+
if a.dtype != b.dtype:
73+
a = a.to(C.dtype.element_ty)
74+
b = b.to(C.dtype.element_ty)
75+
acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
76+
77+
acc = acc.to(C.dtype.element_ty)
78+
# rematerialize rm and rn to save registers
79+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
80+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
81+
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
82+
mask = (rm < M)[:, None] & (rn < N)[None, :]
83+
# handles write-back with reduction-splitting
84+
tl.store(C, acc, mask=mask)
85+
86+
87+
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]
88+
89+
90+
def get_higher_dtype(a, b):
91+
if a is b:
92+
return a
93+
94+
assert a in _ordered_datatypes
95+
assert b in _ordered_datatypes
96+
97+
for d in _ordered_datatypes:
98+
if a is d:
99+
return b
100+
if b is d:
101+
return a
102+
103+
104+
def mm(a, b):
105+
device = a.device
106+
# handle non-contiguous inputs if necessary
107+
if a.stride(0) > 1 and a.stride(1) > 1:
108+
a = a.contiguous()
109+
if b.stride(0) > 1 and b.stride(1) > 1:
110+
b = b.contiguous()
111+
# checks constraints
112+
assert a.shape[1] == b.shape[0], "incompatible dimensions"
113+
M, K = a.shape
114+
_, N = b.shape
115+
# allocates output
116+
c_dtype = get_higher_dtype(a.dtype, b.dtype)
117+
c = torch.empty((M, N), device=device, dtype=c_dtype)
118+
# launch kernel
119+
grid = lambda META: (
120+
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
121+
)
122+
123+
mm_kernel[grid](
124+
a,
125+
b,
126+
c,
127+
M,
128+
N,
129+
K,
130+
a.stride(0),
131+
a.stride(1),
132+
b.stride(0),
133+
b.stride(1),
134+
c.stride(0),
135+
c.stride(1),
136+
32,
137+
32,
138+
32,
139+
GROUP_M=8,
140+
)
141+
142+
return c
143+
144+
@pytest.mark.interpreter
145+
@pytest.mark.parametrize("M, N, K", [(1, 1, 32), (15, 160, 1024), (495, 5333, 71)])
146+
@pytest.mark.parametrize("dtype", [torch.float32])
147+
def test_accuracy_mm(M, N, K, dtype):
148+
device = 'cpu'
149+
a = torch.randn((M, K), dtype=dtype, device=device)
150+
b = torch.randn((K, N), dtype=dtype, device=device)
151+
152+
ref_out = torch.mm(a, b)
153+
res_out = mm(a, b)
154+
155+
torch.testing.assert_close(res_out, ref_out, atol=1e-2, rtol=0)
156+
157+
158+
if __name__ == "__main__":
159+
benchmark.select_cpu_backend()
160+
M, N, K = (495, 5333, 71)
161+
test_accuracy_mm(M, N, K, torch.float32)

test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,20 +96,18 @@ module {
9696
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_19_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
9797
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32>
9898
// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>)
99-
// CHECK: [[VAR_20_:%.+]] = arith.minsi [[VAR_18_]], [[CST_4_]] : index
100-
// CHECK-DAG: [[VAR_21_:%.+]] = arith.subi [[CST_4_]], [[VAR_20_]] : index
101-
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_20_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
99+
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
102100
// CHECK-NOT: separator of consecutive DAGs
103-
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_21_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
104-
// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_20_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>>
105-
// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_20_]]{{.}} [2, [[VAR_21_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>>
101+
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_19_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
102+
// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>>
103+
// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_18_]]{{.}} [2, [[VAR_19_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>>
106104
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>>
107105
// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1], offset: ?>>
108-
// CHECK: [[VAR_22_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32>
109-
// CHECK: bufferization.materialize_in_destination [[VAR_22_]] in writable [[VAR_reinterpret_cast_]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> ()
110-
// CHECK-DAG: [[VAR_23_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_9_]] : index
111-
// CHECK-DAG: [[VAR_24_:%.+]] = arith.addi [[VAR_arg16_]], [[VAR_11_]] : index
112-
// CHECK: scf.yield [[VAR_23_]], [[VAR_24_]] : index, index
106+
// CHECK: [[VAR_20_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32>
107+
// CHECK: bufferization.materialize_in_destination [[VAR_20_]] in writable [[VAR_reinterpret_cast_]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> ()
108+
// CHECK-DAG: [[VAR_21_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_9_]] : index
109+
// CHECK-DAG: [[VAR_22_:%.+]] = arith.addi [[VAR_arg16_]], [[VAR_11_]] : index
110+
// CHECK: scf.yield [[VAR_21_]], [[VAR_22_]] : index, index
113111
// CHECK: }
114112
// CHECK: return
115113
// CHECK: }

test/Conversion/StructuredToMemref/wraparound_stacked.mlir

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -83,26 +83,25 @@ module {
8383
// CHECK: [[VAR_13_:%.+]] = arith.addi [[VAR_3_]], [[VAR_12_]] : index
8484
// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_13_]], [[VAR_11_]] : index
8585
// CHECK: [[VAR_15_:%.+]] = arith.divsi [[VAR_14_]], [[VAR_1_]] : index
86-
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: {{.}}[[VAR_15_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
87-
// CHECK-DAG: [[VAR_16_:%.+]] = arith.subi [[CST_4_]], [[VAR_15_]] : index
86+
// CHECK: [[VAR_16_:%.+]] = arith.minsi [[VAR_15_]], [[CST_4_]] : index
87+
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: {{.}}[[VAR_16_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
88+
// CHECK-DAG: [[VAR_17_:%.+]] = arith.subi [[CST_4_]], [[VAR_16_]] : index
8889
// CHECK-NOT: separator of consecutive DAGs
89-
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: {{.}}[[VAR_16_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
90+
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: {{.}}[[VAR_17_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
9091
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32>
9192
// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>)
92-
// CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_15_]], [[CST_4_]] : index
93-
// CHECK-DAG: [[VAR_18_:%.+]] = arith.subi [[CST_4_]], [[VAR_17_]] : index
94-
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
93+
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
9594
// CHECK-NOT: separator of consecutive DAGs
96-
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_18_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
97-
// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1]>>
98-
// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_17_]], 0] {{.}}[[VAR_18_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1], offset: ?>>
95+
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
96+
// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1]>>
97+
// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_16_]], 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1], offset: ?>>
9998
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<?x3xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[4, 1]>>
10099
// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref<?x3xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[4, 1], offset: ?>>
101-
// CHECK: [[VAR_19_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32>
102-
// CHECK: bufferization.materialize_in_destination [[VAR_19_]] in writable [[VAR_reinterpret_cast_]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> ()
103-
// CHECK-DAG: [[VAR_20_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_9_]] : index
104-
// CHECK-DAG: [[VAR_21_:%.+]] = arith.addi [[VAR_arg16_]], [[VAR_9_]] : index
105-
// CHECK: scf.yield [[VAR_20_]], [[VAR_21_]] : index, index
100+
// CHECK: [[VAR_18_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32>
101+
// CHECK: bufferization.materialize_in_destination [[VAR_18_]] in writable [[VAR_reinterpret_cast_]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> ()
102+
// CHECK-DAG: [[VAR_19_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_9_]] : index
103+
// CHECK-DAG: [[VAR_20_:%.+]] = arith.addi [[VAR_arg16_]], [[VAR_9_]] : index
104+
// CHECK: scf.yield [[VAR_19_]], [[VAR_20_]] : index, index
106105
// CHECK: }
107106
// CHECK: return
108107
// CHECK: }

0 commit comments

Comments
 (0)