My code:

import torch
import bitsandbytes as bnb

torch.manual_seed(42)
# 1. Create a real-valued Linear layer
linear_fp = torch.nn.Linear(8, 4).cuda().half()
# 2. Modify the weight: add some noise
with torch.no_grad():
    linear_fp.weight[:, 1] += 0.2 * torch.randn(4).cuda().half()
    linear_fp.weight[:, 4] += 0.2 * torch.randn(4).cuda().half()
print("Modified weight:", linear_fp.weight)
# 3. Quantize: set the threshold to be small
linear_int8 = bnb.nn.Linear8bitLt(
    8, 4,
    has_fp16_weights=False,
    threshold=0.1,
)
linear_int8.load_state_dict(linear_fp.state_dict())
linear_int8 = linear_int8.to(0)
# 4. Forward pass
x = torch.randn(2, 8).cuda().half()
with torch.no_grad():
    _ = linear_int8(x)
# 5. Inspect state.idx
print("State.idx:", linear_int8.state.idx.cpu().numpy())
But it always returns [0, 1, 2, 3, 4, 5, 6, 7]. I thought columns with values above the 0.1 threshold should not be quantized, yet every column index shows up here. Did I misunderstand something?
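For context, here is how I understand the outlier-column selection in LLM.int8() is supposed to work. This is a pure-PyTorch sketch of my mental model (comparing per-column activation magnitudes against the threshold), not the actual bitsandbytes implementation, and the tensor values are made up for illustration:

```python
import torch

# Toy activations: column 3 carries large-magnitude values (an "outlier" column).
x = torch.tensor([[0.5, -1.2,  0.3, 25.0,  0.8, -0.4,  1.1, -0.9],
                  [0.2,  0.7, -0.6, 18.0, -1.3,  0.5, -0.2,  0.4]])

threshold = 6.0  # any column whose magnitude exceeds this stays in fp16

# Per-column maximum absolute activation.
col_max = x.abs().max(dim=0).values

# Columns exceeding the threshold are treated as outliers and, in my
# understanding, should be excluded from int8 quantization.
outlier_idx = torch.where(col_max > threshold)[0]
print("outlier columns:", outlier_idx.tolist())  # only column 3 exceeds 6.0
```

With this toy input, only column 3 is selected; that is why I expected `state.idx` to contain just a few indices rather than all eight.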