
Commit 811aa6c

lint
1 parent a1c7c61 commit 811aa6c

3 files changed: 45 additions & 50 deletions


bitsandbytes/nn/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -4,15 +4,15 @@
 # LICENSE file in the root directory of this source tree.
 from .modules import (
     Embedding,
+    Embedding4bit,
+    Embedding8bit,
+    EmbeddingFP4,
+    EmbeddingNF4,
     Int8Params,
     Linear4bit,
     Linear8bitLt,
     LinearFP4,
     LinearNF4,
-    Embedding8bit,
-    Embedding4bit,
-    EmbeddingFP4,
-    EmbeddingNF4,
     OutlierAwareLinear,
     Params4bit,
     StableEmbedding,
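The re-sort is plain alphabetical ordering of the export list; the new embedding classes remain importable from `bitsandbytes.nn`. A minimal sketch:

```python
from bitsandbytes.nn import Embedding4bit, Embedding8bit, EmbeddingFP4, EmbeddingNF4
```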

bitsandbytes/nn/modules.py

Lines changed: 26 additions & 33 deletions
@@ -467,7 +467,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):

     def forward(self, x: torch.Tensor):
         fix_4bit_weight_quant_state_from_module(self)
-
+
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
@@ -684,6 +684,7 @@ class Embedding8bit(nn.Embedding):
         int8_module = int8_module.to(0)  # Quantization happens here
         ```
     """
+
     def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
         super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
         self.dtype = self.weight.data.dtype
@@ -694,10 +695,8 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         raise NotImplementedError("saving Embedding4bit module is not implemented")

     def forward(self, input: Tensor) -> Tensor:
-        if not hasattr(self.weight, 'SCB'):
-            raise RuntimeError(
-                "Embedding layer is not quantized. Please call .cuda() or .to(device) first."
-            )
+        if not hasattr(self.weight, "SCB"):
+            raise RuntimeError("Embedding layer is not quantized. Please call .cuda() or .to(device) first.")

         rows = self.weight.data
         row_stats = self.weight.SCB
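As the guard above spells out, `Embedding8bit` only works after the weights have been quantized by moving the module to an accelerator. A minimal usage sketch (the sizes are illustrative):

```python
import torch
import bitsandbytes as bnb

emb = bnb.nn.Embedding8bit(num_embeddings=128, embedding_dim=64)
emb = emb.to("cuda")  # quantization happens here; calling emb(...) before this raises RuntimeError
tokens = torch.randint(low=0, high=128, size=(10,), device="cuda")
out = emb(tokens)     # dequantized lookup, per the class docstring
```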
@@ -733,6 +732,7 @@ class Embedding4bit(nn.Embedding):
         quantized_module = quantized_module.to(0)  # Quantization happens here
         ```
     """
+
     def __init__(
         self,
         num_embeddings,
@@ -762,22 +762,17 @@ def __init__(
             "This will lead to slow inference.",
         )

-
     def _forward_with_partial_dequantize(self, input: Tensor):
         assert self.embedding_dim % self.weight.quant_state.blocksize == 0

-        w_4bit_uint8 = (
-            self.weight.data.view(torch.uint8)
-            .view(self.num_embeddings * self.embedding_dim // 2, 1)
-        )
+        w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1)

         output_4bit = torch.nn.functional.embedding(
             weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2),
             input=input,
         ).view(-1, 1)
         assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1)

-
         blocks_per_emb = self.embedding_dim // self.weight.blocksize

         absmax = self.weight.quant_state.absmax
@@ -786,16 +781,16 @@ def _forward_with_partial_dequantize(self, input: Tensor):
         output_absmax = torch.nn.functional.embedding(
             weight=absmax.view(self.num_embeddings, blocks_per_emb),
             input=input,
-        ).view(-1,)
+        ).view(
+            -1,
+        )
         assert output_absmax.shape == (input.numel() * blocks_per_emb,)

         output_quant_state = copy.deepcopy(self.weight.quant_state)
         output_quant_state.absmax = output_absmax
         output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim))

-        output = bnb.functional.dequantize_4bit(
-            output_4bit, output_quant_state
-        )
+        output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state)
         assert output.shape == (*input.shape, self.embedding_dim)

         return output.to(self.dtype)
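For reference, the shape bookkeeping in the partial-dequantize path above works out as follows; a standalone sketch with illustrative sizes (blocksize 64, the `quantize_4bit` default, is assumed):

```python
num_embeddings, embedding_dim, blocksize = 128, 64, 64
input_numel = 10 * 10  # e.g. input_shape == (10, 10)

packed_bytes_per_row = embedding_dim // 2    # two 4-bit codes per uint8 byte -> 32
blocks_per_emb = embedding_dim // blocksize  # one absmax scale per block -> 1

# Shapes after the two embedding gathers, matching the asserts in the diff:
assert (input_numel * packed_bytes_per_row, 1) == (3200, 1)  # output_4bit
assert (input_numel * blocks_per_emb,) == (100,)             # output_absmax
# Final dequantized output: (*input_shape, embedding_dim) == (10, 10, 64)
```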
@@ -808,10 +803,8 @@ def forward(self, input: Tensor) -> Tensor:

         if self.embedding_dim % self.weight.quant_state.blocksize == 0:
             return self._forward_with_partial_dequantize(input)
-
-        dequantized_weight = bnb.functional.dequantize_4bit(
-            self.weight.data, self.weight.quant_state
-        )
+
+        dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state)

         return torch.nn.functional.embedding(
             weight=dequantized_weight,
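Taken together with the previous hunk, `Embedding4bit.forward` dispatches on divisibility: the partial path dequantizes only the gathered rows, while this fallback dequantizes the entire table on every call, which is what the "slow inference" warning in `__init__` refers to. A paraphrased sketch, not the verbatim source:

```python
# Paraphrase of Embedding4bit.forward:
if self.embedding_dim % self.weight.quant_state.blocksize == 0:
    return self._forward_with_partial_dequantize(input)  # dequantize only gathered rows
dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state)
return torch.nn.functional.embedding(weight=dequantized_weight, input=input)  # full-table dequantize
```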
@@ -829,13 +822,13 @@ def __init__(
         device=None,
     ):
         super().__init__(
-                num_embeddings,
-                embedding_dim,
-                dtype=dtype,
-                quant_type="fp4",
-                quant_storage=quant_storage,
-                device=device,
-                )
+            num_embeddings,
+            embedding_dim,
+            dtype=dtype,
+            quant_type="fp4",
+            quant_storage=quant_storage,
+            device=device,
+        )


 class EmbeddingNF4(Embedding4bit):
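`EmbeddingFP4` and `EmbeddingNF4` are thin subclasses that pin `quant_type` to "fp4" and "nf4" respectively. A minimal construction sketch (the sizes and the `quant_storage` dtype are illustrative; quantization happens on the device move, per the `Embedding4bit` docstring):

```python
import torch
import bitsandbytes as bnb

emb_fp4 = bnb.nn.EmbeddingFP4(num_embeddings=128, embedding_dim=64, quant_storage=torch.uint8)
emb_nf4 = bnb.nn.EmbeddingNF4(num_embeddings=128, embedding_dim=64)
emb_fp4, emb_nf4 = emb_fp4.to("cuda"), emb_nf4.to("cuda")  # quantization happens here
```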
@@ -848,13 +841,13 @@ def __init__(
         device=None,
     ):
         super().__init__(
-                num_embeddings,
-                embedding_dim,
-                dtype=dtype,
-                quant_type="nf4",
-                quant_storage=quant_storage,
-                device=device,
-                )
+            num_embeddings,
+            embedding_dim,
+            dtype=dtype,
+            quant_type="nf4",
+            quant_storage=quant_storage,
+            device=device,
+        )


 class Linear8bitLt(nn.Linear):

tests/test_modules.py

Lines changed: 15 additions & 13 deletions
@@ -1,9 +1,9 @@
+import inspect
 import math

 import einops
 import pytest
 import torch
-import inspect
 from torch import nn

 import bitsandbytes as bnb
@@ -620,7 +620,8 @@ def test_fp8linear():
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
 @pytest.mark.parametrize(
-    "embedding_class,quant_storage", [
+    "embedding_class,quant_storage",
+    [
         (bnb.nn.Embedding8bit, None),
         (bnb.nn.EmbeddingFP4, torch.uint8),
         (bnb.nn.EmbeddingFP4, torch.float32),
@@ -632,9 +633,9 @@ def test_fp8linear():
 def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
     num_embeddings = 128

-    src_weight = (
-        (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(torch.float32) * 2 - 1
-    )  # Embeddings filled with {-1, 1} values. It should compress losslessly
+    src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
+        torch.float32
+    ) * 2 - 1  # Embeddings filled with {-1, 1} values. It should compress losslessly

     emb_base = nn.Embedding(
         num_embeddings=num_embeddings,
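The `{-1, 1}` weights are what make the lossless comparison possible: every quantization block then has `absmax == 1.0`, and the int8 and 4-bit quantization grids both contain ±1 exactly. A standalone sketch of the same idea (CUDA and the `quantize_4bit` default blocksize assumed):

```python
import torch
import bitsandbytes as bnb

w = ((torch.randn(128, 64) > 0).to(torch.float32) * 2 - 1).cuda()  # values in {-1, +1}
q, state = bnb.functional.quantize_4bit(w, quant_type="nf4")
torch.testing.assert_close(bnb.functional.dequantize_4bit(q, state), w)
```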
@@ -652,7 +653,7 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
     emb_base.cuda()
     e.cuda()

-    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device='cuda')
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")

     torch.testing.assert_close(
         actual=e(input_tokens),
@@ -663,7 +664,8 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
 @pytest.mark.parametrize(
-    "embedding_class,quant_storage", [
+    "embedding_class,quant_storage",
+    [
         (bnb.nn.Embedding8bit, None),
         (bnb.nn.EmbeddingFP4, torch.uint8),
         (bnb.nn.EmbeddingFP4, torch.float32),
@@ -695,7 +697,7 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage):
     emb_base.cuda()
     e.cuda()

-    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device='cuda')
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")

     torch.testing.assert_close(
         actual=e(input_tokens),
@@ -740,7 +742,7 @@ def test_4bit_embedding_warnings():
     with pytest.warns(UserWarning, match=r"inference."):
         net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1)
         net.cuda()
-        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device='cuda')
+        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
         net(inp)


@@ -752,9 +754,9 @@ def test_4bit_embedding_weight_fsdp_fix():

     module.cuda()

-    setattr(module.weight, "quant_state", None)
+    module.weight.quant_state = None

-    input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device='cuda')
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")

     module(input_tokens)

@@ -769,9 +771,9 @@ def test_4bit_linear_weight_fsdp_fix():

     module.cuda()

-    setattr(module.weight, "quant_state", None)
+    module.weight.quant_state = None

-    input_tensor = torch.randn((1, inp_size), device='cuda')
+    input_tensor = torch.randn((1, inp_size), device="cuda")

     module(input_tensor)

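Both FSDP-fix tests simulate FSDP dropping `weight.quant_state` and rely on `fix_4bit_weight_quant_state_from_module`, called at the top of `forward` in the modules.py hunk above, to restore it. A condensed sketch of the linear case (the 64-feature sizes are illustrative, CUDA assumed):

```python
import torch
import bitsandbytes as bnb

module = bnb.nn.Linear4bit(64, 64).cuda()     # quantizes on the device move
module.weight.quant_state = None              # simulate FSDP wiping the quant state
module(torch.randn((1, 64), device="cuda"))   # forward() restores quant_state from the module
```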