Skip to content

Commit ef087bc

Browse files
committed
lint
1 parent 951ccfa commit ef087bc

3 files changed

Lines changed: 45 additions & 50 deletions

File tree

bitsandbytes/nn/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
# LICENSE file in the root directory of this source tree.
55
from .modules import (
66
Embedding,
7+
Embedding4bit,
8+
Embedding8bit,
9+
EmbeddingFP4,
10+
EmbeddingNF4,
711
Int8Params,
812
Linear4bit,
913
Linear8bitLt,
1014
LinearFP4,
1115
LinearNF4,
12-
Embedding8bit,
13-
Embedding4bit,
14-
EmbeddingFP4,
15-
EmbeddingNF4,
1616
OutlierAwareLinear,
1717
Params4bit,
1818
StableEmbedding,

bitsandbytes/nn/modules.py

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
462462

463463
def forward(self, x: torch.Tensor):
464464
fix_4bit_weight_quant_state_from_module(self)
465-
465+
466466
# weights are cast automatically as Int8Params, but the bias has to be cast manually
467467
if self.bias is not None and self.bias.dtype != x.dtype:
468468
self.bias.data = self.bias.data.to(x.dtype)
@@ -679,6 +679,7 @@ class Embedding8bit(nn.Embedding):
679679
int8_module = int8_module.to(0) # Quantization happens here
680680
```
681681
"""
682+
682683
def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
683684
super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
684685
self.dtype = self.weight.data.dtype
@@ -689,10 +690,8 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
689690
raise NotImplementedError("saving Embedding4bit module is not implemented")
690691

691692
def forward(self, input: Tensor) -> Tensor:
692-
if not hasattr(self.weight, 'SCB'):
693-
raise RuntimeError(
694-
"Embedding layer is not quantized. Please call .cuda() or .to(device) first."
695-
)
693+
if not hasattr(self.weight, "SCB"):
694+
raise RuntimeError("Embedding layer is not quantized. Please call .cuda() or .to(device) first.")
696695

697696
rows = self.weight.data
698697
row_stats = self.weight.SCB
@@ -728,6 +727,7 @@ class Embedding4bit(nn.Embedding):
728727
quantized_module = quantized_module.to(0) # Quantization happens here
729728
```
730729
"""
730+
731731
def __init__(
732732
self,
733733
num_embeddings,
@@ -757,22 +757,17 @@ def __init__(
757757
"This will lead to slow inference.",
758758
)
759759

760-
761760
def _forward_with_partial_dequantize(self, input: Tensor):
762761
assert self.embedding_dim % self.weight.quant_state.blocksize == 0
763762

764-
w_4bit_uint8 = (
765-
self.weight.data.view(torch.uint8)
766-
.view(self.num_embeddings * self.embedding_dim // 2, 1)
767-
)
763+
w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1)
768764

769765
output_4bit = torch.nn.functional.embedding(
770766
weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2),
771767
input=input,
772768
).view(-1, 1)
773769
assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1)
774770

775-
776771
blocks_per_emb = self.embedding_dim // self.weight.blocksize
777772

778773
absmax = self.weight.quant_state.absmax
@@ -781,16 +776,16 @@ def _forward_with_partial_dequantize(self, input: Tensor):
781776
output_absmax = torch.nn.functional.embedding(
782777
weight=absmax.view(self.num_embeddings, blocks_per_emb),
783778
input=input,
784-
).view(-1,)
779+
).view(
780+
-1,
781+
)
785782
assert output_absmax.shape == (input.numel() * blocks_per_emb,)
786783

787784
output_quant_state = copy.deepcopy(self.weight.quant_state)
788785
output_quant_state.absmax = output_absmax
789786
output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim))
790787

791-
output = bnb.functional.dequantize_4bit(
792-
output_4bit, output_quant_state
793-
)
788+
output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state)
794789
assert output.shape == (*input.shape, self.embedding_dim)
795790

796791
return output.to(self.dtype)
@@ -803,10 +798,8 @@ def forward(self, input: Tensor) -> Tensor:
803798

804799
if self.embedding_dim % self.weight.quant_state.blocksize == 0:
805800
return self._forward_with_partial_dequantize(input)
806-
807-
dequantized_weight = bnb.functional.dequantize_4bit(
808-
self.weight.data, self.weight.quant_state
809-
)
801+
802+
dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state)
810803

811804
return torch.nn.functional.embedding(
812805
weight=dequantized_weight,
@@ -824,13 +817,13 @@ def __init__(
824817
device=None,
825818
):
826819
super().__init__(
827-
num_embeddings,
828-
embedding_dim,
829-
dtype=dtype,
830-
quant_type="fp4",
831-
quant_storage=quant_storage,
832-
device=device,
833-
)
820+
num_embeddings,
821+
embedding_dim,
822+
dtype=dtype,
823+
quant_type="fp4",
824+
quant_storage=quant_storage,
825+
device=device,
826+
)
834827

835828

836829
class EmbeddingNF4(Embedding4bit):
@@ -843,13 +836,13 @@ def __init__(
843836
device=None,
844837
):
845838
super().__init__(
846-
num_embeddings,
847-
embedding_dim,
848-
dtype=dtype,
849-
quant_type="nf4",
850-
quant_storage=quant_storage,
851-
device=device,
852-
)
839+
num_embeddings,
840+
embedding_dim,
841+
dtype=dtype,
842+
quant_type="nf4",
843+
quant_storage=quant_storage,
844+
device=device,
845+
)
853846

854847

855848
class Linear8bitLt(nn.Linear):

tests/test_modules.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
import inspect
12
import math
23

34
import einops
45
import pytest
56
import torch
6-
import inspect
77
from torch import nn
88

99
import bitsandbytes as bnb
@@ -620,7 +620,8 @@ def test_fp8linear():
620620
@pytest.mark.parametrize("embedding_dim", [64, 65])
621621
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
622622
@pytest.mark.parametrize(
623-
"embedding_class,quant_storage", [
623+
"embedding_class,quant_storage",
624+
[
624625
(bnb.nn.Embedding8bit, None),
625626
(bnb.nn.EmbeddingFP4, torch.uint8),
626627
(bnb.nn.EmbeddingFP4, torch.float32),
@@ -632,9 +633,9 @@ def test_fp8linear():
632633
def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
633634
num_embeddings = 128
634635

635-
src_weight = (
636-
(torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(torch.float32) * 2 - 1
637-
) # Embeddings filled with {-1, 1} values. It should compress losslessly
636+
src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
637+
torch.float32
638+
) * 2 - 1 # Embeddings filled with {-1, 1} values. It should compress losslessly
638639

639640
emb_base = nn.Embedding(
640641
num_embeddings=num_embeddings,
@@ -652,7 +653,7 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
652653
emb_base.cuda()
653654
e.cuda()
654655

655-
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device='cuda')
656+
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
656657

657658
torch.testing.assert_close(
658659
actual=e(input_tokens),
@@ -663,7 +664,8 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
663664
@pytest.mark.parametrize("embedding_dim", [64, 65])
664665
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
665666
@pytest.mark.parametrize(
666-
"embedding_class,quant_storage", [
667+
"embedding_class,quant_storage",
668+
[
667669
(bnb.nn.Embedding8bit, None),
668670
(bnb.nn.EmbeddingFP4, torch.uint8),
669671
(bnb.nn.EmbeddingFP4, torch.float32),
@@ -695,7 +697,7 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
695697
emb_base.cuda()
696698
e.cuda()
697699

698-
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device='cuda')
700+
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
699701

700702
torch.testing.assert_close(
701703
actual=e(input_tokens),
@@ -740,7 +742,7 @@ def test_4bit_embedding_warnings():
740742
with pytest.warns(UserWarning, match=r"inference."):
741743
net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1)
742744
net.cuda()
743-
inp = torch.randint(low=0, high=num_embeddings, size=(1,), device='cuda')
745+
inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
744746
net(inp)
745747

746748

@@ -752,9 +754,9 @@ def test_4bit_embedding_weight_fsdp_fix():
752754

753755
module.cuda()
754756

755-
setattr(module.weight, "quant_state", None)
757+
module.weight.quant_state = None
756758

757-
input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device='cuda')
759+
input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
758760

759761
module(input_tokens)
760762

@@ -769,9 +771,9 @@ def test_4bit_linear_weight_fsdp_fix():
769771

770772
module.cuda()
771773

772-
setattr(module.weight, "quant_state", None)
774+
module.weight.quant_state = None
773775

774-
input_tensor = torch.randn((1, inp_size), device='cuda')
776+
input_tensor = torch.randn((1, inp_size), device="cuda")
775777

776778
module(input_tensor)
777779

0 commit comments

Comments
 (0)