@@ -1446,116 +1446,155 @@ __global__ void kgemm_4bit_inference_naive(
     int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,
     int lda, int ldb, int ldc, int blocksize
 ) {
-
-    // per threadblock:
-    // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
-    // THREADS/BNB_WARP_SIZE warps -> that many loads per iter
-    // 1xwarp_size * warp_size x warps -> 1 x warps outputs per thread block
     typedef bnb_cub::WarpReduce<float> WarpReduce;
     __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / BNB_WARP_SIZE];
 
     const int warp_idx = threadIdx.x / BNB_WARP_SIZE;
     const int warp_lane = threadIdx.x % BNB_WARP_SIZE;
     const int row_B = (THREADS / BNB_WARP_SIZE) * blockIdx.x + warp_idx;
     const int offset_B = ldb * row_B;
-    const int num_values_8bit = num_values_4bit / 2;
-    float local_C = 0.0f;
+    constexpr int num_values_8bit = num_values_4bit / 2;
+
+    float local_C0 = 0.0f;
+    float local_C1 = 0.0f;
+    float local_C2 = 0.0f;
+    float local_C3 = 0.0f;
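+    // Four partial accumulators, summed just before the warp reduction, so the unrolled
+    // multiply-adds need not serialize on a single dependency chain.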
 
     unsigned char local_B_4bit[num_values_8bit];
-    T local_B[num_values_4bit / 4];
-    T local_A[num_values_4bit / 4];
-    __shared__ T quant_map[16];
-    T local_absmax = T(0.0f);
-
-    if (threadIdx.x < 16)
-        quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));
-    // for(int i = threadIdx.x; i < 16; i++)
-    //     quant_map[i] = T(__ldg(&datatype[i]));
+    __shared__ float quant_map[32];
+    float local_absmax = 0.0f;
+
+    if (threadIdx.x < 16) {
+        float val = __ldg(&datatype[threadIdx.x]);
+        quant_map[threadIdx.x] = val;
+        quant_map[threadIdx.x + 16] = val;
+    }
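+    // The 16-entry codebook is stored twice; even and odd lanes index different copies (see
+    // qm_offset below), which should reduce shared-memory bank conflicts on the lookups.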
     __syncthreads();
 
-    // A: [1, K]
-    // B: [N, K]
-    for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE * num_values_4bit) {
-        const int inner_idx_halved = inner_idx / 2;
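+    // Rows past M have nothing to compute; exit early (safe: the only __syncthreads() is above).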
+    if (row_B >= M) return;
 
-        // Since blocksize will always be a power-of-2, we avoid more expensive
-        // division by the blocksize and instead use a shift operation.
-        // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize.
-        const int absidx = ((2 * offset_B) + inner_idx) >> (31 - __clz(blocksize));
+    const int stride = BNB_WARP_SIZE * num_values_4bit;
+    const int clz_blocksize = 31 - __clz(blocksize);
+    const int base_absidx = 2 * offset_B;
+    const int qm_offset = (warp_lane & 1) << 4;
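+    // blocksize is a power of two, so a right shift by clz_blocksize is equivalent to dividing by
+    // blocksize (e.g. blocksize == 64: __clz(64) == 25, shift == 6). qm_offset is 0 for even lanes
+    // and 16 for odd lanes, selecting one of the two codebook copies.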
 
-        local_absmax = __ldg(&(absmax[absidx]));
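+    // Loop over the N rows of A: each iteration computes one output element for this warp's row of B.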
+    for (int n_idx = 0; n_idx < N; n_idx++) {
+        const T* __restrict__ A_row = A + n_idx * lda;
 
-        if (row_B < M) {
-            if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
-                // this is the most important for performance considerations
-                reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] =
-                    reinterpret_cast<int4*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
-            } else {
-#pragma unroll
-                for (int j = 0; j < (num_values_8bit); j++)
-                    if ((inner_idx_halved) + j < (K / 2))
-                        local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
-                    else
-                        local_B_4bit[j] = 0b01110111;
-            }
-        } else {
-#pragma unroll
-            for (int j = 0; j < (num_values_8bit); j++)
-                local_B_4bit[j] = 0b01110111;
+        local_C0 = 0.0f;
+        local_C1 = 0.0f;
+        local_C2 = 0.0f;
+        local_C3 = 0.0f;
+
+        int inner_idx = warp_lane * num_values_4bit;
+        int inner_idx_halved = inner_idx >> 1;
+        int4 prefetch_B;
+        float prefetch_absmax;
+
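+        // Prefetch the first tile of B and its absmax before the main loop; each iteration below
+        // then loads the next tile while the current one is consumed, helping hide memory latency.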
+        if (inner_idx < K) {
+            prefetch_absmax = __ldg(&absmax[(base_absidx + inner_idx) >> clz_blocksize]);
+            if ((inner_idx_halved + num_values_8bit) < (K >> 1))
+                prefetch_B = reinterpret_cast<int4*>(B)[(offset_B + inner_idx_halved) / num_values_8bit];
         }
 
-        for (int i = 0; i < 4; i++) {
-#pragma unroll
-            for (int k = 0; k < num_values_8bit / 4; k++) {
-#if BNB_BF16_AVAILABLE
-                local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
-                local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
-#else
-                // bf16 multipliation not supported
-                local_B[k * 2] =
-                    T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * (float)local_absmax);
-                local_B[k * 2 + 1] =
-                    T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * (float)local_absmax);
-#endif
-            }
+        for (; inner_idx < K; inner_idx += stride) {
+            inner_idx_halved = inner_idx >> 1;
 
-            if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
-                // this is also relatively important for performance
-                if (BITS == 16) {
-                    reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
-                        reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 4) + i];
-                } else {
-                    reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
-                        reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0];
-                    reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] =
-                        reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1];
-                }
+            local_absmax = prefetch_absmax;
 
-            } else
+            if (__builtin_expect((inner_idx_halved + num_values_8bit) < (K >> 1), 1)) {
+                reinterpret_cast<int4&>(local_B_4bit[0]) = prefetch_B;
+            } else {
 #pragma unroll
-                for (int k = 0; k < num_values_4bit / 4; k++)
-                    if (inner_idx + (i * num_values_4bit / 4) + k < K)
-                        local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)];
-                    else
-                        local_A[k] = T(0.0f);
+                for (int j = 0; j < num_values_8bit; j++)
+                    local_B_4bit[j] = ((inner_idx_halved + j) < (K >> 1)) ? B[offset_B + inner_idx_halved + j] : 0x77;
+            }
+
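+            // Kick off the loads for the next iteration's tile so they can overlap with the
+            // dequantization and multiply-adds below.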
+            int next_inner_idx = inner_idx + stride;
+            int next_inner_idx_halved = next_inner_idx >> 1;
+            if (next_inner_idx < K) {
+                prefetch_absmax = __ldg(&absmax[(base_absidx + next_inner_idx) >> clz_blocksize]);
+                if ((next_inner_idx_halved + num_values_8bit) < (K >> 1))
+                    prefetch_B = reinterpret_cast<int4*>(B)[(offset_B + next_inner_idx_halved) / num_values_8bit];
+            }
 
-            // accumulate in float; small performance hit for Ampere, but lower error for outputs
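+            // Each packed byte holds two 4-bit code indices (high nibble first); each index selects
+            // a codebook entry, which is then scaled by this block's absmax.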
+            float b0 = quant_map[qm_offset + (local_B_4bit[0] >> 4)] * local_absmax;
+            float b1 = quant_map[qm_offset + (local_B_4bit[0] & 0xF)] * local_absmax;
+            float b2 = quant_map[qm_offset + (local_B_4bit[1] >> 4)] * local_absmax;
+            float b3 = quant_map[qm_offset + (local_B_4bit[1] & 0xF)] * local_absmax;
+            float b4 = quant_map[qm_offset + (local_B_4bit[2] >> 4)] * local_absmax;
+            float b5 = quant_map[qm_offset + (local_B_4bit[2] & 0xF)] * local_absmax;
+            float b6 = quant_map[qm_offset + (local_B_4bit[3] >> 4)] * local_absmax;
+            float b7 = quant_map[qm_offset + (local_B_4bit[3] & 0xF)] * local_absmax;
+            float b8 = quant_map[qm_offset + (local_B_4bit[4] >> 4)] * local_absmax;
+            float b9 = quant_map[qm_offset + (local_B_4bit[4] & 0xF)] * local_absmax;
+            float b10 = quant_map[qm_offset + (local_B_4bit[5] >> 4)] * local_absmax;
+            float b11 = quant_map[qm_offset + (local_B_4bit[5] & 0xF)] * local_absmax;
+            float b12 = quant_map[qm_offset + (local_B_4bit[6] >> 4)] * local_absmax;
+            float b13 = quant_map[qm_offset + (local_B_4bit[6] & 0xF)] * local_absmax;
+            float b14 = quant_map[qm_offset + (local_B_4bit[7] >> 4)] * local_absmax;
+            float b15 = quant_map[qm_offset + (local_B_4bit[7] & 0xF)] * local_absmax;
+            float b16 = quant_map[qm_offset + (local_B_4bit[8] >> 4)] * local_absmax;
+            float b17 = quant_map[qm_offset + (local_B_4bit[8] & 0xF)] * local_absmax;
+            float b18 = quant_map[qm_offset + (local_B_4bit[9] >> 4)] * local_absmax;
+            float b19 = quant_map[qm_offset + (local_B_4bit[9] & 0xF)] * local_absmax;
+            float b20 = quant_map[qm_offset + (local_B_4bit[10] >> 4)] * local_absmax;
+            float b21 = quant_map[qm_offset + (local_B_4bit[10] & 0xF)] * local_absmax;
+            float b22 = quant_map[qm_offset + (local_B_4bit[11] >> 4)] * local_absmax;
+            float b23 = quant_map[qm_offset + (local_B_4bit[11] & 0xF)] * local_absmax;
+            float b24 = quant_map[qm_offset + (local_B_4bit[12] >> 4)] * local_absmax;
+            float b25 = quant_map[qm_offset + (local_B_4bit[12] & 0xF)] * local_absmax;
+            float b26 = quant_map[qm_offset + (local_B_4bit[13] >> 4)] * local_absmax;
+            float b27 = quant_map[qm_offset + (local_B_4bit[13] & 0xF)] * local_absmax;
+            float b28 = quant_map[qm_offset + (local_B_4bit[14] >> 4)] * local_absmax;
+            float b29 = quant_map[qm_offset + (local_B_4bit[14] & 0xF)] * local_absmax;
+            float b30 = quant_map[qm_offset + (local_B_4bit[15] >> 4)] * local_absmax;
+            float b31 = quant_map[qm_offset + (local_B_4bit[15] & 0xF)] * local_absmax;
+
+            if (__builtin_expect(inner_idx + 32 <= K, 1)) {
+                int4 a_vec0 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8];
+                int4 a_vec1 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 1];
+                int4 a_vec2 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 2];
+                int4 a_vec3 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 3];
+
+                const T* a0 = reinterpret_cast<const T*>(&a_vec0);
+                const T* a1 = reinterpret_cast<const T*>(&a_vec1);
+                const T* a2 = reinterpret_cast<const T*>(&a_vec2);
+                const T* a3 = reinterpret_cast<const T*>(&a_vec3);
+
+                local_C0 += (float)a0[0]*b0; local_C1 += (float)a0[1]*b1;
+                local_C2 += (float)a0[2]*b2; local_C3 += (float)a0[3]*b3;
+                local_C0 += (float)a0[4]*b4; local_C1 += (float)a0[5]*b5;
+                local_C2 += (float)a0[6]*b6; local_C3 += (float)a0[7]*b7;
+                local_C0 += (float)a1[0]*b8; local_C1 += (float)a1[1]*b9;
+                local_C2 += (float)a1[2]*b10; local_C3 += (float)a1[3]*b11;
+                local_C0 += (float)a1[4]*b12; local_C1 += (float)a1[5]*b13;
+                local_C2 += (float)a1[6]*b14; local_C3 += (float)a1[7]*b15;
+                local_C0 += (float)a2[0]*b16; local_C1 += (float)a2[1]*b17;
+                local_C2 += (float)a2[2]*b18; local_C3 += (float)a2[3]*b19;
+                local_C0 += (float)a2[4]*b20; local_C1 += (float)a2[5]*b21;
+                local_C2 += (float)a2[6]*b22; local_C3 += (float)a2[7]*b23;
+                local_C0 += (float)a3[0]*b24; local_C1 += (float)a3[1]*b25;
+                local_C2 += (float)a3[2]*b26; local_C3 += (float)a3[3]*b27;
+                local_C0 += (float)a3[4]*b28; local_C1 += (float)a3[5]*b29;
+                local_C2 += (float)a3[6]*b30; local_C3 += (float)a3[7]*b31;
+            } else {
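+                // Partial tail of K: load A element-wise and treat out-of-range entries as zero,
+                // so the 0x77-padded B nibbles contribute nothing to the sum.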
+                float b_vals[32] = {b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15,
+                                    b16,b17,b18,b19,b20,b21,b22,b23,b24,b25,b26,b27,b28,b29,b30,b31};
 #pragma unroll
-            for (int k = 0; k < num_values_4bit / 4; k++) {
-#if BNB_BF16_AVAILABLE
-                local_C += (float)(local_A[k] * local_B[k]);
-#else
-                // bf16 multipliation not supported
-                local_C += ((float)local_A[k] * (float)local_B[k]);
-#endif
+                for (int k = 0; k < 32; k++) {
+                    float a_val = (inner_idx + k < K) ? (float)A_row[inner_idx + k] : 0.0f;
+                    local_C0 += a_val * b_vals[k];
+                }
             }
         }
-    }
 
-    local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
+        float local_C = local_C0 + local_C1 + local_C2 + local_C3;
+        local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
 
-    if (row_B < M && warp_lane == 0)
-        out[row_B] = T(local_C);
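+        // Lane 0 holds the warp-reduced dot product and writes the (n_idx, row_B) output element.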
+        if (warp_lane == 0)
+            out[n_idx * ldc + row_B] = T(local_C);
+    }
 }
 
 template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n) {
@@ -1595,6 +1634,18 @@ template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(
     int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
     float* out, int lda, int ldb, int ldc, int blocksize
 );
+template __global__ void kgemm_4bit_inference_naive<half, 64, 16>(
+    int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out,
+    int lda, int ldb, int ldc, int blocksize
+);
+template __global__ void kgemm_4bit_inference_naive<bnb_bfloat16, 64, 16>(
+    int M, int N, int K, bnb_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
+    bnb_bfloat16* out, int lda, int ldb, int ldc, int blocksize
+);
+template __global__ void kgemm_4bit_inference_naive<float, 64, 32>(
+    int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
+    float* out, int lda, int ldb, int ldc, int blocksize
+);
 
 template __global__ void kdequant_mm_int32_fp16<4, 512>(
     int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,