@@ -1,9 +1,9 @@
+import inspect
 import math
 
 import einops
 import pytest
 import torch
-import inspect
 from torch import nn
 
 import bitsandbytes as bnb
@@ -620,7 +620,8 @@ def test_fp8linear():
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
 @pytest.mark.parametrize(
-    "embedding_class,quant_storage", [
+    "embedding_class,quant_storage",
+    [
         (bnb.nn.Embedding8bit, None),
         (bnb.nn.EmbeddingFP4, torch.uint8),
         (bnb.nn.EmbeddingFP4, torch.float32),
@@ -632,9 +633,9 @@ def test_fp8linear():
 def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
     num_embeddings = 128
 
-    src_weight = (
-        (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(torch.float32) * 2 - 1
-    )  # Embeddings filled with {-1, 1} values. It should compress losslessly
+    src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
+        torch.float32
+    ) * 2 - 1  # Embeddings filled with {-1, 1} values. It should compress losslessly
 
     emb_base = nn.Embedding(
         num_embeddings=num_embeddings,
@@ -652,7 +653,7 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
     emb_base.cuda()
     e.cuda()
 
-    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device='cuda')
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
 
     torch.testing.assert_close(
         actual=e(input_tokens),
@@ -663,7 +664,8 @@ def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_s
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
 @pytest.mark.parametrize(
-    "embedding_class,quant_storage", [
+    "embedding_class,quant_storage",
+    [
         (bnb.nn.Embedding8bit, None),
         (bnb.nn.EmbeddingFP4, torch.uint8),
         (bnb.nn.EmbeddingFP4, torch.float32),
@@ -695,7 +697,7 @@ def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_stor
     emb_base.cuda()
     e.cuda()
 
-    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device='cuda')
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
 
     torch.testing.assert_close(
         actual=e(input_tokens),
@@ -740,7 +742,7 @@ def test_4bit_embedding_warnings():
     with pytest.warns(UserWarning, match=r"inference."):
         net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1)
         net.cuda()
-        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device='cuda')
+        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
         net(inp)
 
 
@@ -752,9 +754,9 @@ def test_4bit_embedding_weight_fsdp_fix():
 
     module.cuda()
 
-    setattr(module.weight, "quant_state", None)
+    module.weight.quant_state = None
 
-    input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device='cuda')
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
 
     module(input_tokens)
 
@@ -769,9 +771,9 @@ def test_4bit_linear_weight_fsdp_fix():
 
     module.cuda()
 
-    setattr(module.weight, "quant_state", None)
+    module.weight.quant_state = None
 
-    input_tensor = torch.randn((1, inp_size), device='cuda')
+    input_tensor = torch.randn((1, inp_size), device="cuda")
 
     module(input_tensor)
 