6 | 6 | BOOLEAN_TRIPLES, |
7 | 7 | TRUE_FALSE, |
8 | 8 | describe_dtype, |
9 | | - get_test_dims, |
10 | 9 | id_formatter, |
11 | 10 | ) |
12 | 11 |
@@ -136,10 +135,10 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec |
136 | 135 | torch.testing.assert_close(gradBias1, gradBias2) |
137 | 136 |
138 | 137 |
139 | | -@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) |
140 | | -@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) |
141 | | -@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) |
142 | | -@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) |
| 138 | +@pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1")) |
| 139 | +@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) |
| 140 | +@pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3")) |
| 141 | +@pytest.mark.parametrize("dim4", [96], ids=id_formatter("dim4")) |
143 | 142 | @pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"]) |
144 | 143 | @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) |
145 | 144 | @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) |
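Note: this hunk replaces the randomized shapes from `get_test_dims` with fixed ones, so the parametrization is deterministic across runs while keeping the `dim2=0` empty-tensor edge case. As a rough sketch of what the removed helper did (a reconstruction under assumptions, not the exact `tests/helpers.py` implementation):

```python
import random

# Hypothetical sketch of the removed helper: draw n random sizes in [low, high].
# A fresh draw at every test collection meant shapes varied from run to run.
def get_test_dims(low: int, high: int, *, n: int) -> list[int]:
    return [random.randint(low, high) for _ in range(n)]

# Old usage: one random dim1 in [16, 64], plus the explicit 0 edge case for dim2.
dim1 = get_test_dims(16, 64, n=1)          # e.g. [37]
dim2 = [*get_test_dims(32, 96, n=1), 0]    # e.g. [55, 0]
```

Pinning dim1=48, dim2 in {64, 0}, dim3=64, and dim4=96 removes that run-to-run variation without narrowing the cases covered.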
@@ -231,85 +230,3 @@ def test_matmul_4bit( |
231 | 230 |
232 | 231 | if req_grad[2]: |
233 | 232 | torch.testing.assert_close(gradBias1, gradBias2) |
234 | | - |
235 | | - |
236 | | -@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) |
237 | | -@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) |
238 | | -@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) |
239 | | -@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) |
240 | | -@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) |
241 | | -@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) |
242 | | -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) |
243 | | -@pytest.mark.parametrize( |
244 | | - "funcs", |
245 | | - [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], |
246 | | - ids=["matmul_fp8_mixed", "matmul_fp8_global"], |
247 | | -) |
248 | | -def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): |
249 | | - dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) |
250 | | - dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) |
251 | | - req_grad = list(req_grad) |
252 | | - req_grad[2] = False |
253 | | - |
254 | | - for i in range(3): |
255 | | - # normal multiply |
256 | | - if funcs[0] in [torch.mm, torch.matmul]: |
257 | | - A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) |
258 | | - B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) |
259 | | - target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) |
260 | | - |
261 | | - torch.nn.init.xavier_uniform_(B) |
262 | | - |
263 | | - fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device) |
264 | | - bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device) |
265 | | - |
266 | | - if not transpose[0] and transpose[1]: |
267 | | - out_torch = funcs[0](A, B.t()) |
268 | | - out_bnb = funcs[1](A, B.t(), fw_code, bw_code) |
269 | | - elif not transpose[0] and not transpose[1]: |
270 | | - out_torch = funcs[0](A, B) |
271 | | - out_bnb = funcs[1](A, B, fw_code, bw_code) |
272 | | - |
273 | | - assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" |
274 | | - |
275 | | - n = out_bnb.numel() |
276 | | - err = torch.abs(out_bnb - out_torch).float().mean().item() |
277 | | - if n > 0: |
278 | | - assert err < 0.115 |
279 | | - # assert err < 0.20 |
280 | | - if any(req_grad): |
281 | | - out_bnb.data.copy_(out_torch) |
282 | | - torch.cuda.synchronize() |
283 | | - loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() |
284 | | - loss_bnb.backward() |
285 | | - gradA1 = A.grad |
286 | | - gradB1 = B.grad |
287 | | - A.grad = None |
288 | | - B.grad = None |
289 | | - |
290 | | - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() |
291 | | - loss_torch.backward() |
292 | | - gradA2 = A.grad |
293 | | - gradB2 = B.grad |
294 | | - A.grad = None |
295 | | - B.grad = None |
296 | | - |
297 | | - if req_grad[0]: |
298 | | - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) |
299 | | - |
300 | | - if req_grad[1]: |
301 | | - n = gradB1.numel() |
302 | | - if dim2 > 0: |
303 | | - assert torch.abs(gradB1).sum() > 0.0 |
304 | | - assert torch.abs(gradB2).sum() > 0.0 |
305 | | - else: |
306 | | - assert torch.abs(gradB1).sum() == 0.0 |
307 | | - assert torch.abs(gradB2).sum() == 0.0 |
308 | | - idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) |
309 | | - |
310 | | - assert (idx == 0).sum().item() <= n * 0.1 |
311 | | - idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) |
312 | | - assert (idx == 0).sum().item() <= n * 0.02 |
313 | | - grad_err = (gradB1 - gradB2).abs().mean() |
314 | | - assert grad_err.item() < 0.003 |
315 | | - torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) |
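For reference, the deleted `test_matmul_fp8` exercised `bnb.research.matmul_fp8_mixed` and `bnb.research.matmul_fp8_global` with an E4M3-style codebook for the forward pass and an E5M2-style codebook for the backward pass. A minimal sketch of that call pattern, mirroring the removed lines (shapes here are illustrative assumptions):

```python
import torch
import bitsandbytes as bnb

A = torch.randn(64, 64, device="cuda", dtype=torch.float16)
B = torch.randn(64, 96, device="cuda", dtype=torch.float16)

# Forward codebook ~ E4M3 (4 exponent / 3 mantissa bits), backward ~ E5M2,
# built with create_fp8_map(signed, exponent_bits, precision_bits, total_bits)
# exactly as in the removed test.
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)

out = bnb.research.matmul_fp8_mixed(A, B, fw_code, bw_code)
assert out.dtype == A.dtype  # the removed test asserted the dtype round-trips
```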