Skip to content

Commit 8bf2856

Browse files
committed
Fix pre-commit and empty input fallback
1 parent 6aeade4 commit 8bf2856

2 files changed

Lines changed: 47 additions & 30 deletions

File tree

bitsandbytes/autograd/_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def matmul_4bit(
396396
return MatMul4Bit.apply(A, B, out, bias, quant_state)
397397

398398
num_a_rows = A.numel() // A.shape[-1]
399-
if num_a_rows <= FUSED_4BIT_DEQUANT_LIMIT and A.requires_grad == False and A.device.type != "hpu":
399+
if 0 < num_a_rows <= FUSED_4BIT_DEQUANT_LIMIT and A.requires_grad == False and A.device.type != "hpu":
400400
if A.shape[-1] % quant_state.blocksize != 0:
401401
warn(
402402
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",

csrc/kernels.cu

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,7 +1471,8 @@ __global__ void kgemm_4bit_inference_naive(
14711471
}
14721472
__syncthreads();
14731473

1474-
if (row_B >= M) return;
1474+
if (row_B >= M)
1475+
return;
14751476

14761477
const int stride = BNB_WARP_SIZE * num_values_4bit;
14771478
const int clz_blocksize = 31 - __clz(blocksize);
@@ -1518,16 +1519,16 @@ __global__ void kgemm_4bit_inference_naive(
15181519
prefetch_B = reinterpret_cast<int4*>(B)[(offset_B + next_inner_idx_halved) / num_values_8bit];
15191520
}
15201521

1521-
float b0 = quant_map[qm_offset + (local_B_4bit[0] >> 4)] * local_absmax;
1522-
float b1 = quant_map[qm_offset + (local_B_4bit[0] & 0xF)] * local_absmax;
1523-
float b2 = quant_map[qm_offset + (local_B_4bit[1] >> 4)] * local_absmax;
1524-
float b3 = quant_map[qm_offset + (local_B_4bit[1] & 0xF)] * local_absmax;
1525-
float b4 = quant_map[qm_offset + (local_B_4bit[2] >> 4)] * local_absmax;
1526-
float b5 = quant_map[qm_offset + (local_B_4bit[2] & 0xF)] * local_absmax;
1527-
float b6 = quant_map[qm_offset + (local_B_4bit[3] >> 4)] * local_absmax;
1528-
float b7 = quant_map[qm_offset + (local_B_4bit[3] & 0xF)] * local_absmax;
1529-
float b8 = quant_map[qm_offset + (local_B_4bit[4] >> 4)] * local_absmax;
1530-
float b9 = quant_map[qm_offset + (local_B_4bit[4] & 0xF)] * local_absmax;
1522+
float b0 = quant_map[qm_offset + (local_B_4bit[0] >> 4)] * local_absmax;
1523+
float b1 = quant_map[qm_offset + (local_B_4bit[0] & 0xF)] * local_absmax;
1524+
float b2 = quant_map[qm_offset + (local_B_4bit[1] >> 4)] * local_absmax;
1525+
float b3 = quant_map[qm_offset + (local_B_4bit[1] & 0xF)] * local_absmax;
1526+
float b4 = quant_map[qm_offset + (local_B_4bit[2] >> 4)] * local_absmax;
1527+
float b5 = quant_map[qm_offset + (local_B_4bit[2] & 0xF)] * local_absmax;
1528+
float b6 = quant_map[qm_offset + (local_B_4bit[3] >> 4)] * local_absmax;
1529+
float b7 = quant_map[qm_offset + (local_B_4bit[3] & 0xF)] * local_absmax;
1530+
float b8 = quant_map[qm_offset + (local_B_4bit[4] >> 4)] * local_absmax;
1531+
float b9 = quant_map[qm_offset + (local_B_4bit[4] & 0xF)] * local_absmax;
15311532
float b10 = quant_map[qm_offset + (local_B_4bit[5] >> 4)] * local_absmax;
15321533
float b11 = quant_map[qm_offset + (local_B_4bit[5] & 0xF)] * local_absmax;
15331534
float b12 = quant_map[qm_offset + (local_B_4bit[6] >> 4)] * local_absmax;
@@ -1562,25 +1563,41 @@ __global__ void kgemm_4bit_inference_naive(
15621563
const T* a2 = reinterpret_cast<const T*>(&a_vec2);
15631564
const T* a3 = reinterpret_cast<const T*>(&a_vec3);
15641565

1565-
local_C0 += (float)a0[0]*b0; local_C1 += (float)a0[1]*b1;
1566-
local_C2 += (float)a0[2]*b2; local_C3 += (float)a0[3]*b3;
1567-
local_C0 += (float)a0[4]*b4; local_C1 += (float)a0[5]*b5;
1568-
local_C2 += (float)a0[6]*b6; local_C3 += (float)a0[7]*b7;
1569-
local_C0 += (float)a1[0]*b8; local_C1 += (float)a1[1]*b9;
1570-
local_C2 += (float)a1[2]*b10; local_C3 += (float)a1[3]*b11;
1571-
local_C0 += (float)a1[4]*b12; local_C1 += (float)a1[5]*b13;
1572-
local_C2 += (float)a1[6]*b14; local_C3 += (float)a1[7]*b15;
1573-
local_C0 += (float)a2[0]*b16; local_C1 += (float)a2[1]*b17;
1574-
local_C2 += (float)a2[2]*b18; local_C3 += (float)a2[3]*b19;
1575-
local_C0 += (float)a2[4]*b20; local_C1 += (float)a2[5]*b21;
1576-
local_C2 += (float)a2[6]*b22; local_C3 += (float)a2[7]*b23;
1577-
local_C0 += (float)a3[0]*b24; local_C1 += (float)a3[1]*b25;
1578-
local_C2 += (float)a3[2]*b26; local_C3 += (float)a3[3]*b27;
1579-
local_C0 += (float)a3[4]*b28; local_C1 += (float)a3[5]*b29;
1580-
local_C2 += (float)a3[6]*b30; local_C3 += (float)a3[7]*b31;
1566+
local_C0 += (float)a0[0] * b0;
1567+
local_C1 += (float)a0[1] * b1;
1568+
local_C2 += (float)a0[2] * b2;
1569+
local_C3 += (float)a0[3] * b3;
1570+
local_C0 += (float)a0[4] * b4;
1571+
local_C1 += (float)a0[5] * b5;
1572+
local_C2 += (float)a0[6] * b6;
1573+
local_C3 += (float)a0[7] * b7;
1574+
local_C0 += (float)a1[0] * b8;
1575+
local_C1 += (float)a1[1] * b9;
1576+
local_C2 += (float)a1[2] * b10;
1577+
local_C3 += (float)a1[3] * b11;
1578+
local_C0 += (float)a1[4] * b12;
1579+
local_C1 += (float)a1[5] * b13;
1580+
local_C2 += (float)a1[6] * b14;
1581+
local_C3 += (float)a1[7] * b15;
1582+
local_C0 += (float)a2[0] * b16;
1583+
local_C1 += (float)a2[1] * b17;
1584+
local_C2 += (float)a2[2] * b18;
1585+
local_C3 += (float)a2[3] * b19;
1586+
local_C0 += (float)a2[4] * b20;
1587+
local_C1 += (float)a2[5] * b21;
1588+
local_C2 += (float)a2[6] * b22;
1589+
local_C3 += (float)a2[7] * b23;
1590+
local_C0 += (float)a3[0] * b24;
1591+
local_C1 += (float)a3[1] * b25;
1592+
local_C2 += (float)a3[2] * b26;
1593+
local_C3 += (float)a3[3] * b27;
1594+
local_C0 += (float)a3[4] * b28;
1595+
local_C1 += (float)a3[5] * b29;
1596+
local_C2 += (float)a3[6] * b30;
1597+
local_C3 += (float)a3[7] * b31;
15811598
} else {
1582-
float b_vals[32] = {b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15,
1583-
b16,b17,b18,b19,b20,b21,b22,b23,b24,b25,b26,b27,b28,b29,b30,b31};
1599+
float b_vals[32] = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15,
1600+
b16, b17, b18, b19, b20, b21, b22, b23, b24, b25, b26, b27, b28, b29, b30, b31};
15841601
#pragma unroll
15851602
for (int k = 0; k < 32; k++) {
15861603
float a_val = (inner_idx + k < K) ? (float)A_row[inner_idx + k] : 0.0f;

0 commit comments

Comments
 (0)