Skip to content

Commit feb1139

Browse files
More test cleanup
1 parent 5102319 commit feb1139

2 files changed

Lines changed: 169 additions & 171 deletions

File tree

tests/test_functional.py

Lines changed: 166 additions & 171 deletions
Original file line number | Diff line number | Diff line change
@@ -369,9 +369,9 @@ def test_approx_igemm(self, dim1, dim2, quant_methods, batched):
369369
# print(mean(errors))
370370
# print(mean(relerrors))
371371

372-
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
373-
@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
374-
@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
372+
@pytest.mark.parametrize("hidden_dim", [32, 256], ids=id_formatter("hidden_dim"))
373+
@pytest.mark.parametrize("batch_dim", [16, 256], ids=id_formatter("batch_dim"))
374+
@pytest.mark.parametrize("seq_dim", [16, 256], ids=id_formatter("seq_dim"))
375375
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
376376
def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim):
377377
hidden_dim = hidden_dim - (hidden_dim % 32)
@@ -415,9 +415,9 @@ def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim):
415415

416416
torch.testing.assert_close(out.float(), out2)
417417

418-
@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
419-
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
420-
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
418+
@pytest.mark.parametrize("seq_dim", [32, 256, 512], ids=id_formatter("seq_dim"))
419+
@pytest.mark.parametrize("hidden_dim", [64, 1024, 4096], ids=id_formatter("hidden_dim"))
420+
@pytest.mark.parametrize("batch_dim", [2, 8, 16], ids=id_formatter("batch_dim"))
421421
def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
422422
seq_dim = seq_dim - (seq_dim % 32)
423423
hidden_dim = hidden_dim - (hidden_dim % 32)
@@ -431,9 +431,9 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
431431

432432
torch.testing.assert_close(out.float(), out2)
433433

434-
@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
435-
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
436-
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
434+
@pytest.mark.parametrize("seq_dim", [32, 512], ids=id_formatter("seq_dim"))
435+
@pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
436+
@pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
437437
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
438438
def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
439439
def min_max(x):
@@ -501,10 +501,10 @@ def min_max(x):
501501
assert mean(errs) < 0.015
502502
assert mean(relerrs) < 0.3
503503

504-
@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
505-
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
506-
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
507-
@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
504+
@pytest.mark.parametrize("dim1", [1, 64], ids=id_formatter("dim1"))
505+
@pytest.mark.parametrize("dim2", [32, 128], ids=id_formatter("dim2"))
506+
@pytest.mark.parametrize("dim3", [32, 256], ids=id_formatter("dim3"))
507+
@pytest.mark.parametrize("dim4", [32, 256], ids=id_formatter("dim4"))
508508
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
509509
def test_ibmm(self, dim1, dim2, dim3, dim4, transpose):
510510
dim2 = dim2 - (dim2 % 16)
@@ -760,8 +760,8 @@ def test_coo_int8_vectorwise_quant(self, dim1, dim2):
760760

761761

762762
class TestSpMMFunctional:
763-
@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
764-
@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
763+
@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
764+
@pytest.mark.parametrize("dim2", [128, 512], ids=id_formatter("dim2"))
765765
@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
766766
def test_spmm_coo(self, dim1, dim2, transposed_B):
767767
threshold = 1.5
@@ -1096,37 +1096,34 @@ def test_4bit_quant(self, dtype, quant_type, blocksize):
10961096
assert err.item() < math.log2(blocksize) * 8e-2
10971097

10981098
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
1099-
def test_4bit_compressed_stats(self, quant_type):
1100-
for blocksize in [128, 64]:
1101-
errs1 = []
1102-
errs2 = []
1103-
for i in range(10):
1104-
A1 = torch.randn(1024, 1024, device="cuda").half()
1105-
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
1106-
q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
1107-
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
1108-
A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
1109-
1110-
err = (A1 - A2).abs().float()
1111-
relerr = (err / (A1.abs().float() + 1e-15)).mean()
1112-
err = err.mean()
1099+
@pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
1100+
def test_4bit_compressed_stats(self, quant_type, blocksize):
1101+
errs1 = []
1102+
errs2 = []
1103+
for i in range(10):
1104+
A1 = torch.randn(1024, 1024, device="cuda").half()
1105+
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
1106+
q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
1107+
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
1108+
A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
11131109

1114-
errs1.append(err.item())
1110+
err = (A1 - A2).abs().float()
1111+
relerr = (err / (A1.abs().float() + 1e-15)).mean()
1112+
err = err.mean()
11151113

1116-
assert err.item() < 0.11
1117-
assert relerr.item() < 0.28
1114+
errs1.append(err.item())
11181115

1119-
err = (A1 - A3).abs().float()
1120-
relerr = (err / (A1.abs().float() + 1e-15)).mean()
1121-
err = err.mean()
1116+
assert err.item() < 0.11
1117+
assert relerr.item() < 0.28
11221118

1123-
errs2.append(err.item())
1119+
err = (A1 - A3).abs().float()
1120+
relerr = (err / (A1.abs().float() + 1e-15)).mean()
1121+
err = err.mean()
11241122

1125-
assert err.item() < 0.11
1126-
assert relerr.item() < 0.28
1123+
errs2.append(err.item())
11271124

1128-
# print(sum(errs1)/len(errs1), blocksize, quant_type)
1129-
# print(sum(errs2)/len(errs2), blocksize, quant_type)
1125+
assert err.item() < 0.11
1126+
assert relerr.item() < 0.28
11301127

11311128
# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
11321129
@pytest.mark.parametrize("quant_type", ["nf4"])
@@ -1169,135 +1166,133 @@ def test_bench_4bit_dequant(self, quant_type):
11691166
[torch.uint8, torch.float16, torch.bfloat16, torch.float32],
11701167
ids=describe_dtype,
11711168
)
1172-
def test_gemv_4bit(self, dtype, storage_type, quant_storage, double_quant, kind):
1173-
for dim in [128, 256, 512, 1024]:
1174-
# for dim in [4*1024]:
1175-
# for dim in [1*16]:
1176-
errs1 = []
1177-
errs2 = []
1178-
errs3 = []
1179-
relerrs1 = []
1180-
relerrs2 = []
1181-
relerrs3 = []
1182-
max_errs1 = []
1183-
max_errs2 = []
1184-
max_errs3 = []
1169+
@pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
1170+
def test_gemv_4bit(self, dim, dtype, storage_type, quant_storage, double_quant, kind):
1171+
errs1 = []
1172+
errs2 = []
1173+
errs3 = []
1174+
relerrs1 = []
1175+
relerrs2 = []
1176+
relerrs3 = []
1177+
max_errs1 = []
1178+
max_errs2 = []
1179+
max_errs3 = []
11851180

1186-
for i in range(100):
1187-
if kind == "fc1":
1188-
A = torch.randn(1, dim, dtype=dtype, device="cuda")
1189-
B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
1190-
elif kind == "fc2":
1191-
A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
1192-
B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
1193-
elif kind == "attn":
1194-
A = torch.randn(1, dim, dtype=dtype, device="cuda")
1195-
B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
1196-
elif kind == "attn_packed":
1197-
A = torch.randn(1, dim, dtype=dtype, device="cuda")
1198-
B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
1199-
1200-
qB, state = F.quantize_4bit(
1201-
B,
1202-
quant_type=storage_type,
1203-
compress_statistics=double_quant,
1204-
quant_storage=quant_storage,
1205-
)
1206-
C3 = torch.matmul(A, B.t())
1207-
C2 = F.gemv_4bit(A, qB.t(), state=state)
1208-
A.requires_grad = True
1209-
C1 = bnb.matmul_4bit(A, qB.t(), state)
1210-
1211-
err1 = (C1 - C2).abs().float()
1212-
err2 = (C3 - C2).abs().float()
1213-
err3 = (C3 - C1).abs().float()
1214-
1215-
mag1 = torch.abs(C1).float() + 1e-5
1216-
mag2 = torch.abs(C3).float() + 1e-5
1217-
mag3 = torch.abs(C3).float() + 1e-5
1218-
1219-
relerr1 = err1 / mag1
1220-
relerr2 = err2 / mag2
1221-
relerr3 = err3 / mag3
1222-
1223-
max_err1 = err1.max()
1224-
max_err2 = err2.max()
1225-
max_err3 = err3.max()
1226-
1227-
errs1.append(err1.mean().item())
1228-
errs2.append(err2.mean().item())
1229-
errs3.append(err3.mean().item())
1230-
1231-
relerrs1.append(relerr1.mean().item())
1232-
relerrs2.append(relerr2.mean().item())
1233-
relerrs3.append(relerr3.mean().item())
1234-
1235-
max_errs1.append(max_err1.item())
1236-
max_errs2.append(max_err2.item())
1237-
max_errs3.append(max_err3.item())
1238-
1239-
c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
1240-
1241-
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False)
1242-
err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
1243-
err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
1244-
err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
1245-
relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
1246-
relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
1247-
relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
1248-
maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
1249-
maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
1250-
maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
1251-
absratio = err2 / err3
1252-
relratio = relerr2 / relerr3
1253-
maxratio = relerr2 / relerr3
1254-
1255-
# for debugging if the tests fails
1256-
#
1257-
# print('='*80)
1258-
# print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
1259-
# print(C1.flatten()[-20:])
1260-
# print(C2.flatten()[-20:])
1261-
# print(f'inference vs training abs: {err1}')
1262-
# print(f'inference vs training rel: {relerr1}')
1263-
# print(f'inference vs training max: {maxerr1}')
1264-
# print(f'inference vs training vs torch err ratio abs: {absratio}')
1265-
# print(f'inference vs training vs torch err ratio rel: {relratio}')
1266-
# print(f'inference vs training vs torch err ratio max: {maxratio}')
1267-
if dtype == torch.float16:
1268-
if dim <= 512:
1269-
assert err1 < 7e-5
1270-
assert relerr1 < 0.0008
1271-
else:
1272-
assert err1 < 6e-5
1273-
assert relerr1 < 2e-4
1274-
assert absratio < 1.005 and absratio > 0.995
1275-
assert relratio < 1.005 and relratio > 0.995
1276-
assert maxratio < 1.005 and maxratio > 0.995
1277-
elif dtype == torch.float32:
1278-
if dim <= 512:
1279-
assert err1 < 5e-8
1280-
assert relerr1 < 1e-6
1281-
assert maxerr1 < 1e-7
1282-
else:
1283-
assert err1 < 5e-8
1284-
assert relerr1 < 8e-6
1285-
assert maxerr1 < 1e-7
1286-
assert absratio < 1.005 and absratio > 0.995
1287-
assert relratio < 1.005 and relratio > 0.995
1288-
assert maxratio < 1.005 and maxratio > 0.995
1289-
elif dtype == torch.bfloat16:
1290-
if dim <= 512:
1291-
assert err1 < 6e-4
1292-
assert relerr1 < 0.007
1293-
assert maxerr1 < 0.015
1294-
else:
1295-
assert err1 < 2e-4
1296-
assert relerr1 < 0.002
1297-
assert maxerr1 < 0.0012
1298-
assert absratio < 1.005 and absratio > 0.995
1299-
assert relratio < 1.04 and relratio > 0.96
1300-
assert maxratio < 1.02 and maxratio > 0.98
1181+
for i in range(100):
1182+
if kind == "fc1":
1183+
A = torch.randn(1, dim, dtype=dtype, device="cuda")
1184+
B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
1185+
elif kind == "fc2":
1186+
A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
1187+
B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
1188+
elif kind == "attn":
1189+
A = torch.randn(1, dim, dtype=dtype, device="cuda")
1190+
B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
1191+
elif kind == "attn_packed":
1192+
A = torch.randn(1, dim, dtype=dtype, device="cuda")
1193+
B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
1194+
1195+
qB, state = F.quantize_4bit(
1196+
B,
1197+
quant_type=storage_type,
1198+
compress_statistics=double_quant,
1199+
quant_storage=quant_storage,
1200+
)
1201+
C3 = torch.matmul(A, B.t())
1202+
C2 = F.gemv_4bit(A, qB.t(), state=state)
1203+
A.requires_grad = True
1204+
C1 = bnb.matmul_4bit(A, qB.t(), state)
1205+
1206+
err1 = (C1 - C2).abs().float()
1207+
err2 = (C3 - C2).abs().float()
1208+
err3 = (C3 - C1).abs().float()
1209+
1210+
mag1 = torch.abs(C1).float() + 1e-5
1211+
mag2 = torch.abs(C3).float() + 1e-5
1212+
mag3 = torch.abs(C3).float() + 1e-5
1213+
1214+
relerr1 = err1 / mag1
1215+
relerr2 = err2 / mag2
1216+
relerr3 = err3 / mag3
1217+
1218+
max_err1 = err1.max()
1219+
max_err2 = err2.max()
1220+
max_err3 = err3.max()
1221+
1222+
errs1.append(err1.mean().item())
1223+
errs2.append(err2.mean().item())
1224+
errs3.append(err3.mean().item())
1225+
1226+
relerrs1.append(relerr1.mean().item())
1227+
relerrs2.append(relerr2.mean().item())
1228+
relerrs3.append(relerr3.mean().item())
1229+
1230+
max_errs1.append(max_err1.item())
1231+
max_errs2.append(max_err2.item())
1232+
max_errs3.append(max_err3.item())
1233+
1234+
c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
1235+
1236+
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False)
1237+
err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
1238+
err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
1239+
err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
1240+
relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
1241+
relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
1242+
relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
1243+
maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
1244+
maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
1245+
maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
1246+
absratio = err2 / err3
1247+
relratio = relerr2 / relerr3
1248+
maxratio = relerr2 / relerr3
1249+
1250+
# for debugging if the tests fails
1251+
#
1252+
# print('='*80)
1253+
# print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
1254+
# print(C1.flatten()[-20:])
1255+
# print(C2.flatten()[-20:])
1256+
# print(f'inference vs training abs: {err1}')
1257+
# print(f'inference vs training rel: {relerr1}')
1258+
# print(f'inference vs training max: {maxerr1}')
1259+
# print(f'inference vs training vs torch err ratio abs: {absratio}')
1260+
# print(f'inference vs training vs torch err ratio rel: {relratio}')
1261+
# print(f'inference vs training vs torch err ratio max: {maxratio}')
1262+
if dtype == torch.float16:
1263+
if dim <= 512:
1264+
assert err1 < 7e-5
1265+
assert relerr1 < 0.0008
1266+
else:
1267+
assert err1 < 6e-5
1268+
assert relerr1 < 2e-4
1269+
assert absratio < 1.005 and absratio > 0.995
1270+
assert relratio < 1.005 and relratio > 0.995
1271+
assert maxratio < 1.005 and maxratio > 0.995
1272+
elif dtype == torch.float32:
1273+
if dim <= 512:
1274+
assert err1 < 5e-8
1275+
assert relerr1 < 1e-6
1276+
assert maxerr1 < 1e-7
1277+
else:
1278+
assert err1 < 5e-8
1279+
assert relerr1 < 8e-6
1280+
assert maxerr1 < 1e-7
1281+
assert absratio < 1.005 and absratio > 0.995
1282+
assert relratio < 1.005 and relratio > 0.995
1283+
assert maxratio < 1.005 and maxratio > 0.995
1284+
elif dtype == torch.bfloat16:
1285+
if dim <= 512:
1286+
assert err1 < 6e-4
1287+
assert relerr1 < 0.007
1288+
assert maxerr1 < 0.015
1289+
else:
1290+
assert err1 < 2e-4
1291+
assert relerr1 < 0.002
1292+
assert maxerr1 < 0.0012
1293+
assert absratio < 1.005 and absratio > 0.995
1294+
assert relratio < 1.04 and relratio > 0.96
1295+
assert maxratio < 1.02 and maxratio > 0.98
13011296

13021297
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
13031298
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@@ -1363,9 +1358,9 @@ def test_managed():
13631358
assert (A == 17 * (2**3)).sum().item() == n * n
13641359

13651360

1366-
@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1"))
1367-
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2"))
1368-
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3"))
1361+
@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
1362+
@pytest.mark.parametrize("dim2", [64], ids=id_formatter("dim2"))
1363+
@pytest.mark.parametrize("dim3", [128], ids=id_formatter("dim3"))
13691364
@pytest.mark.deprecated
13701365
def test_vector_quant(dim1, dim2, dim3):
13711366
dim2 = dim2 - (dim2 % 16)

0 commit comments

Comments
 (0)