Commit 5102319

Testing cleanup

1 parent b86ff64 · commit 5102319

8 files changed · 161 additions and 130 deletions

benchmarking/optimizer_benchmark.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+"""
+Extracted from tests/test_optim.py
+
+Usage: pytest benchmarking/optimizer_benchmark.py
+"""
+
+import time
+
+import pytest
+from tests.helpers import describe_dtype, id_formatter
+import torch
+
+import bitsandbytes as bnb
+
+str2optimizers = {"paged_adamw": (torch.optim.AdamW, bnb.optim.PagedAdamW)}
+
+
+@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
+@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
+@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
+@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
+@pytest.mark.benchmark
+def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
+    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
+    layers1 = layers1.to(gtype)
+    layers1 = layers1.cuda()
+
+    large_tensor = None
+    if mode == "torch":
+        optim = str2optimizers[optim_name][0](layers1.parameters())
+    else:
+        optim = str2optimizers[optim_name][1](layers1.parameters())
+        # 12 GB
+        large_tensor = torch.empty((int(4.5e9),), device="cuda")
+
+    torch.cuda.synchronize()
+    time.sleep(5)
+
+    num_batches = 5
+    batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
+    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()
+
+    for i in range(num_batches):
+        print(i)
+        b = batches[i]
+        if i == 2:
+            torch.cuda.synchronize()
+            t0 = time.time()
+
+        out1 = layers1(b)
+
+        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
+        loss1.backward()
+        optim.step()
+    torch.cuda.synchronize()
+    print(mode, time.time() - t0)

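Note on the benchmark above: the `large_tensor` allocation deliberately fills most of GPU memory so the paged optimizer is forced to page its state between host and device. A minimal standalone sketch of what the bnb branch exercises (assuming a CUDA device; the layer size here is arbitrary):

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(2048, 2048, device="cuda", dtype=torch.float16)
    optim = bnb.optim.PagedAdamW(model.parameters())  # optimizer state lives in paged memory

    x = torch.randn(8, 2048, device="cuda", dtype=torch.float16)
    model(x).float().mean().backward()
    optim.step()  # state is paged onto the GPU on demand
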
tests/conftest.py

Lines changed: 11 additions & 0 deletions
@@ -1,11 +1,22 @@
 import gc
+import random

+import numpy as np
 import pytest
 import torch


+def _set_seed():
+    torch.manual_seed(0)
+    torch.cuda.manual_seed_all(0)
+    torch.mps.manual_seed(0)
+    np.random.seed(0)
+    random.seed(0)
+
+
 def pytest_runtest_call(item):
     try:
+        _set_seed()
         item.runtest()
     except AssertionError as ae:
         if str(ae) == "Torch not compiled with CUDA enabled":

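Re-seeding inside `pytest_runtest_call` puts every test at the same RNG state across Python's `random`, NumPy, CPU, CUDA, and MPS before it runs, so threshold-based assertions behave reproducibly. The core effect, illustrated standalone:

    import torch

    torch.manual_seed(0)
    a = torch.randn(3)
    torch.manual_seed(0)
    b = torch.randn(3)
    assert torch.equal(a, b)  # identical draws after identical seeding
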
tests/test_autograd.py

Lines changed: 4 additions & 87 deletions
@@ -6,7 +6,6 @@
     BOOLEAN_TRIPLES,
     TRUE_FALSE,
     describe_dtype,
-    get_test_dims,
     id_formatter,
 )

@@ -136,10 +135,10 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
         torch.testing.assert_close(gradBias1, gradBias2)


-@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
+@pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
+@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
+@pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
+@pytest.mark.parametrize("dim4", [96], ids=id_formatter("dim4"))
 @pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"])
 @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
 @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@@ -231,85 +230,3 @@ def test_matmul_4bit(

     if req_grad[2]:
         torch.testing.assert_close(gradBias1, gradBias2)
-
-
-@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
-@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
-@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
-@pytest.mark.parametrize(
-    "funcs",
-    [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)],
-    ids=["matmul_fp8_mixed", "matmul_fp8_global"],
-)
-def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
-    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
-    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    req_grad = list(req_grad)
-    req_grad[2] = False
-
-    for i in range(3):
-        # normal multiply
-        if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
-            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
-
-            torch.nn.init.xavier_uniform_(B)
-
-            fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
-            bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
-
-            if not transpose[0] and transpose[1]:
-                out_torch = funcs[0](A, B.t())
-                out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
-            elif not transpose[0] and not transpose[1]:
-                out_torch = funcs[0](A, B)
-                out_bnb = funcs[1](A, B, fw_code, bw_code)
-
-            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
-
-            n = out_bnb.numel()
-            err = torch.abs(out_bnb - out_torch).float().mean().item()
-            if n > 0:
-                assert err < 0.115
-                # assert err < 0.20
-            if any(req_grad):
-                out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
-                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
-                loss_bnb.backward()
-                gradA1 = A.grad
-                gradB1 = B.grad
-                A.grad = None
-                B.grad = None
-
-                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
-                loss_torch.backward()
-                gradA2 = A.grad
-                gradB2 = B.grad
-                A.grad = None
-                B.grad = None
-
-                if req_grad[0]:
-                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
-
-                if req_grad[1]:
-                    n = gradB1.numel()
-                    if dim2 > 0:
-                        assert torch.abs(gradB1).sum() > 0.0
-                        assert torch.abs(gradB2).sum() > 0.0
-                    else:
-                        assert torch.abs(gradB1).sum() == 0.0
-                        assert torch.abs(gradB2).sum() == 0.0
-                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)

-                    assert (idx == 0).sum().item() <= n * 0.1
-                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
-                    assert (idx == 0).sum().item() <= n * 0.02
-                    grad_err = (gradB1 - gradB2).abs().mean()
-                    assert grad_err.item() < 0.003
-                    torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)

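Pinning the formerly randomized dimensions to literals such as [48] and [64, 0] keeps the parametrized test IDs and coverage stable between runs. For context, a helper of roughly this shape is what the removed calls used (assumed implementation, inferred from the call sites; the real one lives in tests/helpers.py):

    import random

    def get_test_dims(lo: int, hi: int, *, n: int) -> list[int]:
        # n random dimensions in [lo, hi]; collection differed on every run
        return [random.randint(lo, hi) for _ in range(n)]
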
tests/test_deprecated.py

Lines changed: 87 additions & 0 deletions
@@ -3,7 +3,10 @@
 from scipy.stats import norm
 import torch

+import bitsandbytes as bnb
 from bitsandbytes import functional as F
+from tests.helpers import BOOLEAN_TRIPLES, describe_dtype, get_test_dims, id_formatter
+from tests.test_autograd import TRANSPOSE_VALS


 @pytest.mark.deprecated
@@ -121,3 +124,87 @@ def test_percentile_clipping(gtype):
     torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
     torch.testing.assert_close(clip1, clip2)
     torch.testing.assert_close(gnorm1, gnorm2)
+
+
+@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
+@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
+@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
+@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
+@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
+@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
+@pytest.mark.parametrize(
+    "funcs",
+    [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)],
+    ids=["matmul_fp8_mixed", "matmul_fp8_global"],
+)
+@pytest.mark.deprecated
+@pytest.mark.skip("Deprecated functionality, to be removed.")
+def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
+    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+    req_grad = list(req_grad)
+    req_grad[2] = False
+
+    for i in range(3):
+        # normal multiply
+        if funcs[0] in [torch.mm, torch.matmul]:
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+
+            torch.nn.init.xavier_uniform_(B)
+
+            fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
+            bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
+
+            if not transpose[0] and transpose[1]:
+                out_torch = funcs[0](A, B.t())
+                out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
+            elif not transpose[0] and not transpose[1]:
+                out_torch = funcs[0](A, B)
+                out_bnb = funcs[1](A, B, fw_code, bw_code)
+
+            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
+            n = out_bnb.numel()
+            err = torch.abs(out_bnb - out_torch).float().mean().item()
+            if n > 0:
+                assert err < 0.115
+                # assert err < 0.20
+            if any(req_grad):
+                out_bnb.data.copy_(out_torch)
+                torch.cuda.synchronize()
+                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+                loss_bnb.backward()
+                gradA1 = A.grad
+                gradB1 = B.grad
+                A.grad = None
+                B.grad = None
+
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch.backward()
+                gradA2 = A.grad
+                gradB2 = B.grad
+                A.grad = None
+                B.grad = None
+
+                if req_grad[0]:
+                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
+
+                if req_grad[1]:
+                    n = gradB1.numel()
+                    if dim2 > 0:
+                        assert torch.abs(gradB1).sum() > 0.0
+                        assert torch.abs(gradB2).sum() > 0.0
+                    else:
+                        assert torch.abs(gradB1).sum() == 0.0
+                        assert torch.abs(gradB2).sum() == 0.0
+                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+
+                    assert (idx == 0).sum().item() <= n * 0.1
+                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+                    assert (idx == 0).sum().item() <= n * 0.02
+                    grad_err = (gradB1 - gradB2).abs().mean()
+                    assert grad_err.item() < 0.003
+                    torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)

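The relocated test builds its quantization codebooks with `create_fp8_map(signed, exponent_bits, precision_bits, total_bits)`: an E4M3-style map for the forward pass and a wider-range E5M2-style map for gradients. A standalone sketch of round-tripping a tensor through such a codebook (assuming a CUDA device; blockwise quantization is one way to apply the map):

    import torch
    from bitsandbytes import functional as F

    code = F.create_fp8_map(True, 4, 3, 8).cuda()  # signed, 4 exponent / 3 mantissa bits
    x = torch.randn(1024, device="cuda")
    q, state = F.quantize_blockwise(x, code=code)
    x_hat = F.dequantize_blockwise(q, state)
    print((x - x_hat).abs().mean())  # small round-trip error
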
tests/test_functional.py

Lines changed: 1 addition & 1 deletion
@@ -893,7 +893,7 @@ def test_spmm_coo_very_sparse(self, dim1, dim2, dtype, out_func):

     @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
     @pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
-    @pytest.skip("No longer supported")
+    @pytest.mark.skip("No longer supported")
     def test_integrated_sparse_decomp(self, dim1, dim2):
         threshold = 3.0
         for _ in range(k):

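The one-line fix above matters because `pytest.skip(...)` is an imperative call intended for use inside a test body; evaluated at decoration time it fires during import, and recent pytest versions reject module-level `skip` outright. The declarative form is `@pytest.mark.skip(...)`:

    import pytest

    @pytest.mark.skip("No longer supported")  # declarative: skipped at collection
    def test_decorated():
        ...

    def test_imperative():
        pytest.skip("No longer supported")  # imperative: skips when the body runs
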
tests/test_generation.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f
     return tokenizer.decode(outputs[0], skip_special_tokens=True)


-models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"]
+models = ["bigscience/bloom-1b7"]
 dtypes = ["nf4", "fp4"]

tests/test_optim.py

Lines changed: 0 additions & 41 deletions
@@ -604,44 +604,3 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     params = (total_steps - total_steps // 5) * dim1 * dim2
     print(optim_name, gtype, s, params, s / params)
     # assert s < 3.9
-
-
-@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype)
-@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name"))
-@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode"))
-@pytest.mark.benchmark
-def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
-    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
-    layers1 = layers1.to(gtype)
-    layers1 = layers1.cuda()
-
-    large_tensor = None
-    if mode == "torch":
-        optim = str2optimizers[optim_name][0](layers1.parameters())
-    else:
-        optim = str2optimizers[optim_name][1](layers1.parameters())
-        # 12 GB
-        large_tensor = torch.empty((int(4.5e9),), device="cuda")
-
-    torch.cuda.synchronize()
-    time.sleep(5)
-
-    num_batches = 5
-    batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype)
-    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()
-
-    for i in range(num_batches):
-        print(i)
-        b = batches[i]
-        if i == 2:
-            torch.cuda.synchronize()
-            t0 = time.time()
-
-        out1 = layers1(b)
-
-        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
-        loss1.backward()
-        optim.step()
-    torch.cuda.synchronize()
-    print(mode, time.time() - t0)

tests/test_triton.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
     not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
     reason="This test requires triton and a GPU with compute capability 8.0 or higher.",
 )
+@pytest.mark.skip("No longer supported.")
 @pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE)
 def test_switchback(vector_wise_quantization):
     for dim in [83]:
