@@ -1471,7 +1471,8 @@ __global__ void kgemm_4bit_inference_naive(
14711471 }
14721472 __syncthreads ();
14731473
1474- if (row_B >= M) return ;
1474+ if (row_B >= M)
1475+ return ;
14751476
14761477 const int stride = BNB_WARP_SIZE * num_values_4bit;
14771478 const int clz_blocksize = 31 - __clz (blocksize);
@@ -1518,16 +1519,16 @@ __global__ void kgemm_4bit_inference_naive(
15181519 prefetch_B = reinterpret_cast <int4 *>(B)[(offset_B + next_inner_idx_halved) / num_values_8bit];
15191520 }
15201521
1521- float b0 = quant_map[qm_offset + (local_B_4bit[0 ] >> 4 )] * local_absmax;
1522- float b1 = quant_map[qm_offset + (local_B_4bit[0 ] & 0xF )] * local_absmax;
1523- float b2 = quant_map[qm_offset + (local_B_4bit[1 ] >> 4 )] * local_absmax;
1524- float b3 = quant_map[qm_offset + (local_B_4bit[1 ] & 0xF )] * local_absmax;
1525- float b4 = quant_map[qm_offset + (local_B_4bit[2 ] >> 4 )] * local_absmax;
1526- float b5 = quant_map[qm_offset + (local_B_4bit[2 ] & 0xF )] * local_absmax;
1527- float b6 = quant_map[qm_offset + (local_B_4bit[3 ] >> 4 )] * local_absmax;
1528- float b7 = quant_map[qm_offset + (local_B_4bit[3 ] & 0xF )] * local_absmax;
1529- float b8 = quant_map[qm_offset + (local_B_4bit[4 ] >> 4 )] * local_absmax;
1530- float b9 = quant_map[qm_offset + (local_B_4bit[4 ] & 0xF )] * local_absmax;
1522+ float b0 = quant_map[qm_offset + (local_B_4bit[0 ] >> 4 )] * local_absmax;
1523+ float b1 = quant_map[qm_offset + (local_B_4bit[0 ] & 0xF )] * local_absmax;
1524+ float b2 = quant_map[qm_offset + (local_B_4bit[1 ] >> 4 )] * local_absmax;
1525+ float b3 = quant_map[qm_offset + (local_B_4bit[1 ] & 0xF )] * local_absmax;
1526+ float b4 = quant_map[qm_offset + (local_B_4bit[2 ] >> 4 )] * local_absmax;
1527+ float b5 = quant_map[qm_offset + (local_B_4bit[2 ] & 0xF )] * local_absmax;
1528+ float b6 = quant_map[qm_offset + (local_B_4bit[3 ] >> 4 )] * local_absmax;
1529+ float b7 = quant_map[qm_offset + (local_B_4bit[3 ] & 0xF )] * local_absmax;
1530+ float b8 = quant_map[qm_offset + (local_B_4bit[4 ] >> 4 )] * local_absmax;
1531+ float b9 = quant_map[qm_offset + (local_B_4bit[4 ] & 0xF )] * local_absmax;
15311532 float b10 = quant_map[qm_offset + (local_B_4bit[5 ] >> 4 )] * local_absmax;
15321533 float b11 = quant_map[qm_offset + (local_B_4bit[5 ] & 0xF )] * local_absmax;
15331534 float b12 = quant_map[qm_offset + (local_B_4bit[6 ] >> 4 )] * local_absmax;
@@ -1562,25 +1563,41 @@ __global__ void kgemm_4bit_inference_naive(
15621563 const T* a2 = reinterpret_cast <const T*>(&a_vec2);
15631564 const T* a3 = reinterpret_cast <const T*>(&a_vec3);
15641565
1565- local_C0 += (float )a0[0 ]*b0; local_C1 += (float )a0[1 ]*b1;
1566- local_C2 += (float )a0[2 ]*b2; local_C3 += (float )a0[3 ]*b3;
1567- local_C0 += (float )a0[4 ]*b4; local_C1 += (float )a0[5 ]*b5;
1568- local_C2 += (float )a0[6 ]*b6; local_C3 += (float )a0[7 ]*b7;
1569- local_C0 += (float )a1[0 ]*b8; local_C1 += (float )a1[1 ]*b9;
1570- local_C2 += (float )a1[2 ]*b10; local_C3 += (float )a1[3 ]*b11;
1571- local_C0 += (float )a1[4 ]*b12; local_C1 += (float )a1[5 ]*b13;
1572- local_C2 += (float )a1[6 ]*b14; local_C3 += (float )a1[7 ]*b15;
1573- local_C0 += (float )a2[0 ]*b16; local_C1 += (float )a2[1 ]*b17;
1574- local_C2 += (float )a2[2 ]*b18; local_C3 += (float )a2[3 ]*b19;
1575- local_C0 += (float )a2[4 ]*b20; local_C1 += (float )a2[5 ]*b21;
1576- local_C2 += (float )a2[6 ]*b22; local_C3 += (float )a2[7 ]*b23;
1577- local_C0 += (float )a3[0 ]*b24; local_C1 += (float )a3[1 ]*b25;
1578- local_C2 += (float )a3[2 ]*b26; local_C3 += (float )a3[3 ]*b27;
1579- local_C0 += (float )a3[4 ]*b28; local_C1 += (float )a3[5 ]*b29;
1580- local_C2 += (float )a3[6 ]*b30; local_C3 += (float )a3[7 ]*b31;
1566+ local_C0 += (float )a0[0 ] * b0;
1567+ local_C1 += (float )a0[1 ] * b1;
1568+ local_C2 += (float )a0[2 ] * b2;
1569+ local_C3 += (float )a0[3 ] * b3;
1570+ local_C0 += (float )a0[4 ] * b4;
1571+ local_C1 += (float )a0[5 ] * b5;
1572+ local_C2 += (float )a0[6 ] * b6;
1573+ local_C3 += (float )a0[7 ] * b7;
1574+ local_C0 += (float )a1[0 ] * b8;
1575+ local_C1 += (float )a1[1 ] * b9;
1576+ local_C2 += (float )a1[2 ] * b10;
1577+ local_C3 += (float )a1[3 ] * b11;
1578+ local_C0 += (float )a1[4 ] * b12;
1579+ local_C1 += (float )a1[5 ] * b13;
1580+ local_C2 += (float )a1[6 ] * b14;
1581+ local_C3 += (float )a1[7 ] * b15;
1582+ local_C0 += (float )a2[0 ] * b16;
1583+ local_C1 += (float )a2[1 ] * b17;
1584+ local_C2 += (float )a2[2 ] * b18;
1585+ local_C3 += (float )a2[3 ] * b19;
1586+ local_C0 += (float )a2[4 ] * b20;
1587+ local_C1 += (float )a2[5 ] * b21;
1588+ local_C2 += (float )a2[6 ] * b22;
1589+ local_C3 += (float )a2[7 ] * b23;
1590+ local_C0 += (float )a3[0 ] * b24;
1591+ local_C1 += (float )a3[1 ] * b25;
1592+ local_C2 += (float )a3[2 ] * b26;
1593+ local_C3 += (float )a3[3 ] * b27;
1594+ local_C0 += (float )a3[4 ] * b28;
1595+ local_C1 += (float )a3[5 ] * b29;
1596+ local_C2 += (float )a3[6 ] * b30;
1597+ local_C3 += (float )a3[7 ] * b31;
15811598 } else {
1582- float b_vals[32 ] = {b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15,
1583- b16,b17,b18,b19,b20,b21,b22,b23,b24,b25,b26,b27,b28,b29,b30,b31};
1599+ float b_vals[32 ] = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15,
1600+ b16, b17, b18, b19, b20, b21, b22, b23, b24, b25, b26, b27, b28, b29, b30, b31};
15841601#pragma unroll
15851602 for (int k = 0 ; k < 32 ; k++) {
15861603 float a_val = (inner_idx + k < K) ? (float )A_row[inner_idx + k] : 0 .0f ;
0 commit comments