|
| 1 | +import contextlib |
1 | 2 | import inspect |
| 3 | +import logging |
2 | 4 |
|
3 | 5 | import pytest |
4 | 6 | import torch |
|
8 | 10 | from tests.helpers import get_available_devices, id_formatter, is_supported_on_hpu |
9 | 11 |
|
10 | 12 |
|
@contextlib.contextmanager
def caplog_at_level(caplog, level, logger_name):
    """Capture log records at `level` (and above) for `logger_name` only.

    Thin wrapper around pytest's ``caplog.at_level`` so tests can scope the
    capture to the bitsandbytes module logger instead of the root logger.

    Yields the ``caplog`` fixture itself so callers may optionally bind it:
    ``with caplog_at_level(caplog, ...) as cap: ...`` — callers that ignore
    the yielded value behave exactly as before.
    """
    with caplog.at_level(level, logger=logger_name):
        yield caplog
| 17 | + |
| 18 | + |
11 | 19 | class MockArgs: |
12 | 20 | def __init__(self, initial_data): |
13 | 21 | for key in initial_data: |
@@ -453,46 +461,38 @@ def test_embedding_error(device, embedding_class, input_shape, embedding_dim, qu |
453 | 461 |
|
454 | 462 |
|
@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_linear_warnings(device, caplog):
    hidden = 64

    def _build_and_forward(batch_size):
        # Build a stack of 4-bit linear layers and run one forward pass;
        # moving to `device` quantizes the weights, the forward emits the log.
        model = nn.Sequential(*(bnb.nn.Linear4bit(hidden, hidden, quant_type="nf4") for _ in range(10)))
        model = model.to(device)
        batch = torch.rand(batch_size, hidden, device=device, dtype=torch.float16)
        model(batch)

    # batch of 10: expect the "inference or training" warning log
    with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
        _build_and_forward(10)
    assert any("inference or training" in msg for msg in caplog.messages)

    caplog.clear()
    # batch of 1: expect the inference-only warning log
    with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
        _build_and_forward(1)
    assert any("inference." in msg for msg in caplog.messages)
482 | 481 |
|
483 | 482 |
|
@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_embedding_warnings(device, caplog):
    num_embeddings = 128
    default_block_size = 64

    # An embedding_dim one past the default blocksize should trigger the
    # module's warning log during the forward pass.
    with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
        emb = bnb.nn.Embedding4bit(
            num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4"
        )
        emb.to(device)
        token_ids = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
        emb(token_ids)
    assert any("inference" in msg for msg in caplog.messages)
496 | 496 |
|
497 | 497 |
|
498 | 498 | def test_4bit_embedding_weight_fsdp_fix(requires_cuda): |
|
0 commit comments