
Commit 6aeade4

Optimize fused gemm+dequant kernel for ROCm, use it for batch sizes other than 1
1 parent 74994ef commit 6aeade4

5 files changed: 156 additions & 95 deletions

bitsandbytes/_ops.py

Lines changed: 0 additions & 2 deletions
@@ -282,7 +282,6 @@ def _(
     A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
 ) -> torch.Tensor:
     torch._check_is_size(blocksize)
-    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
     torch._check(
         A.dtype in [torch.float16, torch.bfloat16, torch.float32],
         lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",

@@ -312,7 +311,6 @@ def _(
     out: torch.Tensor,
 ) -> None:
     torch._check_is_size(blocksize)
-    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
     torch._check(
         A.dtype in [torch.float16, torch.bfloat16, torch.float32],
         lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",

bitsandbytes/autograd/_functions.py

Lines changed: 6 additions & 1 deletion
@@ -374,6 +374,10 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


+# Above this limit, inference falls back to the dequantize + GEMM path.
+FUSED_4BIT_DEQUANT_LIMIT = 8
+
+
 def matmul_4bit(
     A: torch.Tensor,
     B: torch.Tensor,

@@ -391,7 +395,8 @@ def matmul_4bit(
         else:
             return MatMul4Bit.apply(A, B, out, bias, quant_state)

-    if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
+    num_a_rows = A.numel() // A.shape[-1]
+    if num_a_rows <= FUSED_4BIT_DEQUANT_LIMIT and A.requires_grad == False and A.device.type != "hpu":
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",

bitsandbytes/backends/cuda/ops.py

Lines changed: 3 additions & 2 deletions
@@ -472,10 +472,11 @@ def _gemv_4bit_impl(
     # torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

     m = ct.c_int32(shapeB[0])
-    n = ct.c_int32(1)
+    num_a_rows = A.numel() // A.shape[-1]
+    n = ct.c_int32(num_a_rows)
     k = ct.c_int32(shapeB[1])

-    lda = m
+    lda = ct.c_int32(A.shape[-1])
     ldb = ct.c_int32((A.shape[-1] + 1) // 2)
     ldc = m
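
In the launcher, n is now the number of activation rows instead of a hard-coded 1, and lda becomes the row stride of A (its last dimension) rather than m. A pure-Python mirror of that dimension setup, with illustrative shapes that are not part of the commit:

def gemm_dims(A_shape, shapeB):
    K = A_shape[-1]
    num_a_rows = 1
    for d in A_shape[:-1]:
        num_a_rows *= d
    m = shapeB[0]        # output features = rows of the packed weight B
    n = num_a_rows       # activation rows (was always 1 before this commit)
    k = shapeB[1]        # reduction dimension, must equal K
    lda = K              # elements between consecutive rows of A
    ldb = (K + 1) // 2   # bytes per packed 4-bit row of B
    ldc = m              # stride between per-row output vectors
    return m, n, k, lda, ldb, ldc

print(gemm_dims((4, 4096), (11008, 4096)))  # -> (11008, 4, 4096, 4096, 2048, 11008)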

csrc/kernels.cu

Lines changed: 136 additions & 85 deletions
@@ -1446,116 +1446,155 @@ __global__ void kgemm_4bit_inference_naive(
     int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,
     int lda, int ldb, int ldc, int blocksize
 ) {
-
-    // per threadblock:
-    // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
-    // THREADS/BNB_WARP_SIZE warps -> that many loads per iter
-    // 1xwarp_size * warp_size x warps -> 1 x warps outputs per thread block
     typedef bnb_cub::WarpReduce<float> WarpReduce;
     __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / BNB_WARP_SIZE];

     const int warp_idx = threadIdx.x / BNB_WARP_SIZE;
     const int warp_lane = threadIdx.x % BNB_WARP_SIZE;
     const int row_B = (THREADS / BNB_WARP_SIZE) * blockIdx.x + warp_idx;
     const int offset_B = ldb * row_B;
-    const int num_values_8bit = num_values_4bit / 2;
-    float local_C = 0.0f;
+    constexpr int num_values_8bit = num_values_4bit / 2;
+
+    float local_C0 = 0.0f;
+    float local_C1 = 0.0f;
+    float local_C2 = 0.0f;
+    float local_C3 = 0.0f;

     unsigned char local_B_4bit[num_values_8bit];
-    T local_B[num_values_4bit / 4];
-    T local_A[num_values_4bit / 4];
-    __shared__ T quant_map[16];
-    T local_absmax = T(0.0f);
-
-    if (threadIdx.x < 16)
-        quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));
-    // for(int i = threadIdx.x; i < 16; i++)
-    //     quant_map[i] = T(__ldg(&datatype[i]));
+    __shared__ float quant_map[32];
+    float local_absmax = 0.0f;
+
+    if (threadIdx.x < 16) {
+        float val = __ldg(&datatype[threadIdx.x]);
+        quant_map[threadIdx.x] = val;
+        quant_map[threadIdx.x + 16] = val;
+    }
     __syncthreads();

-    // A: [1, K]
-    // B: [N, K]
-    for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE * num_values_4bit) {
-        const int inner_idx_halved = inner_idx / 2;
+    if (row_B >= M) return;

-        // Since blocksize will always be a power-of-2, we avoid more expensive
-        // division by the blocksize and instead use a shift operation.
-        // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize.
-        const int absidx = ((2 * offset_B) + inner_idx) >> (31 - __clz(blocksize));
+    const int stride = BNB_WARP_SIZE * num_values_4bit;
+    const int clz_blocksize = 31 - __clz(blocksize);
+    const int base_absidx = 2 * offset_B;
+    const int qm_offset = (warp_lane & 1) << 4;

-        local_absmax = __ldg(&(absmax[absidx]));
+    for (int n_idx = 0; n_idx < N; n_idx++) {
+        const T* __restrict__ A_row = A + n_idx * lda;

-        if (row_B < M) {
-            if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
-                // this is the most important for performance considerations
-                reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] =
-                    reinterpret_cast<int4*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
-            } else {
-#pragma unroll
-                for (int j = 0; j < (num_values_8bit); j++)
-                    if ((inner_idx_halved) + j < (K / 2))
-                        local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
-                    else
-                        local_B_4bit[j] = 0b01110111;
-            }
-        } else {
-#pragma unroll
-            for (int j = 0; j < (num_values_8bit); j++)
-                local_B_4bit[j] = 0b01110111;
+        local_C0 = 0.0f;
+        local_C1 = 0.0f;
+        local_C2 = 0.0f;
+        local_C3 = 0.0f;
+
+        int inner_idx = warp_lane * num_values_4bit;
+        int inner_idx_halved = inner_idx >> 1;
+        int4 prefetch_B;
+        float prefetch_absmax;
+
+        if (inner_idx < K) {
+            prefetch_absmax = __ldg(&absmax[(base_absidx + inner_idx) >> clz_blocksize]);
+            if ((inner_idx_halved + num_values_8bit) < (K >> 1))
+                prefetch_B = reinterpret_cast<int4*>(B)[(offset_B + inner_idx_halved) / num_values_8bit];
         }

-        for (int i = 0; i < 4; i++) {
-#pragma unroll
-            for (int k = 0; k < num_values_8bit / 4; k++) {
-#if BNB_BF16_AVAILABLE
-                local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
-                local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
-#else
-                // bf16 multipliation not supported
-                local_B[k * 2] =
-                    T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * (float)local_absmax);
-                local_B[k * 2 + 1] =
-                    T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * (float)local_absmax);
-#endif
-            }
+        for (; inner_idx < K; inner_idx += stride) {
+            inner_idx_halved = inner_idx >> 1;

-            if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
-                // this is also relatively important for performance
-                if (BITS == 16) {
-                    reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
-                        reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 4) + i];
-                } else {
-                    reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
-                        reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0];
-                    reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] =
-                        reinterpret_cast<int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1];
-                }
+            local_absmax = prefetch_absmax;

-            } else
+            if (__builtin_expect((inner_idx_halved + num_values_8bit) < (K >> 1), 1)) {
+                reinterpret_cast<int4&>(local_B_4bit[0]) = prefetch_B;
+            } else {
 #pragma unroll
-                for (int k = 0; k < num_values_4bit / 4; k++)
-                    if (inner_idx + (i * num_values_4bit / 4) + k < K)
-                        local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)];
-                    else
-                        local_A[k] = T(0.0f);
+                for (int j = 0; j < num_values_8bit; j++)
+                    local_B_4bit[j] = ((inner_idx_halved + j) < (K >> 1)) ? B[offset_B + inner_idx_halved + j] : 0x77;
+            }
+
+            int next_inner_idx = inner_idx + stride;
+            int next_inner_idx_halved = next_inner_idx >> 1;
+            if (next_inner_idx < K) {
+                prefetch_absmax = __ldg(&absmax[(base_absidx + next_inner_idx) >> clz_blocksize]);
+                if ((next_inner_idx_halved + num_values_8bit) < (K >> 1))
+                    prefetch_B = reinterpret_cast<int4*>(B)[(offset_B + next_inner_idx_halved) / num_values_8bit];
+            }

-            // accumulate in float; small performance hit for Ampere, but lower error for outputs
+            float b0 = quant_map[qm_offset + (local_B_4bit[0] >> 4)] * local_absmax;
+            float b1 = quant_map[qm_offset + (local_B_4bit[0] & 0xF)] * local_absmax;
+            float b2 = quant_map[qm_offset + (local_B_4bit[1] >> 4)] * local_absmax;
+            float b3 = quant_map[qm_offset + (local_B_4bit[1] & 0xF)] * local_absmax;
+            float b4 = quant_map[qm_offset + (local_B_4bit[2] >> 4)] * local_absmax;
+            float b5 = quant_map[qm_offset + (local_B_4bit[2] & 0xF)] * local_absmax;
+            float b6 = quant_map[qm_offset + (local_B_4bit[3] >> 4)] * local_absmax;
+            float b7 = quant_map[qm_offset + (local_B_4bit[3] & 0xF)] * local_absmax;
+            float b8 = quant_map[qm_offset + (local_B_4bit[4] >> 4)] * local_absmax;
+            float b9 = quant_map[qm_offset + (local_B_4bit[4] & 0xF)] * local_absmax;
+            float b10 = quant_map[qm_offset + (local_B_4bit[5] >> 4)] * local_absmax;
+            float b11 = quant_map[qm_offset + (local_B_4bit[5] & 0xF)] * local_absmax;
+            float b12 = quant_map[qm_offset + (local_B_4bit[6] >> 4)] * local_absmax;
+            float b13 = quant_map[qm_offset + (local_B_4bit[6] & 0xF)] * local_absmax;
+            float b14 = quant_map[qm_offset + (local_B_4bit[7] >> 4)] * local_absmax;
+            float b15 = quant_map[qm_offset + (local_B_4bit[7] & 0xF)] * local_absmax;
+            float b16 = quant_map[qm_offset + (local_B_4bit[8] >> 4)] * local_absmax;
+            float b17 = quant_map[qm_offset + (local_B_4bit[8] & 0xF)] * local_absmax;
+            float b18 = quant_map[qm_offset + (local_B_4bit[9] >> 4)] * local_absmax;
+            float b19 = quant_map[qm_offset + (local_B_4bit[9] & 0xF)] * local_absmax;
+            float b20 = quant_map[qm_offset + (local_B_4bit[10] >> 4)] * local_absmax;
+            float b21 = quant_map[qm_offset + (local_B_4bit[10] & 0xF)] * local_absmax;
+            float b22 = quant_map[qm_offset + (local_B_4bit[11] >> 4)] * local_absmax;
+            float b23 = quant_map[qm_offset + (local_B_4bit[11] & 0xF)] * local_absmax;
+            float b24 = quant_map[qm_offset + (local_B_4bit[12] >> 4)] * local_absmax;
+            float b25 = quant_map[qm_offset + (local_B_4bit[12] & 0xF)] * local_absmax;
+            float b26 = quant_map[qm_offset + (local_B_4bit[13] >> 4)] * local_absmax;
+            float b27 = quant_map[qm_offset + (local_B_4bit[13] & 0xF)] * local_absmax;
+            float b28 = quant_map[qm_offset + (local_B_4bit[14] >> 4)] * local_absmax;
+            float b29 = quant_map[qm_offset + (local_B_4bit[14] & 0xF)] * local_absmax;
+            float b30 = quant_map[qm_offset + (local_B_4bit[15] >> 4)] * local_absmax;
+            float b31 = quant_map[qm_offset + (local_B_4bit[15] & 0xF)] * local_absmax;
+
+            if (__builtin_expect(inner_idx + 32 <= K, 1)) {
+                int4 a_vec0 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8];
+                int4 a_vec1 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 1];
+                int4 a_vec2 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 2];
+                int4 a_vec3 = reinterpret_cast<const int4*>(A_row)[inner_idx / 8 + 3];
+
+                const T* a0 = reinterpret_cast<const T*>(&a_vec0);
+                const T* a1 = reinterpret_cast<const T*>(&a_vec1);
+                const T* a2 = reinterpret_cast<const T*>(&a_vec2);
+                const T* a3 = reinterpret_cast<const T*>(&a_vec3);
+
+                local_C0 += (float)a0[0]*b0; local_C1 += (float)a0[1]*b1;
+                local_C2 += (float)a0[2]*b2; local_C3 += (float)a0[3]*b3;
+                local_C0 += (float)a0[4]*b4; local_C1 += (float)a0[5]*b5;
+                local_C2 += (float)a0[6]*b6; local_C3 += (float)a0[7]*b7;
+                local_C0 += (float)a1[0]*b8; local_C1 += (float)a1[1]*b9;
+                local_C2 += (float)a1[2]*b10; local_C3 += (float)a1[3]*b11;
+                local_C0 += (float)a1[4]*b12; local_C1 += (float)a1[5]*b13;
+                local_C2 += (float)a1[6]*b14; local_C3 += (float)a1[7]*b15;
+                local_C0 += (float)a2[0]*b16; local_C1 += (float)a2[1]*b17;
+                local_C2 += (float)a2[2]*b18; local_C3 += (float)a2[3]*b19;
+                local_C0 += (float)a2[4]*b20; local_C1 += (float)a2[5]*b21;
+                local_C2 += (float)a2[6]*b22; local_C3 += (float)a2[7]*b23;
+                local_C0 += (float)a3[0]*b24; local_C1 += (float)a3[1]*b25;
+                local_C2 += (float)a3[2]*b26; local_C3 += (float)a3[3]*b27;
+                local_C0 += (float)a3[4]*b28; local_C1 += (float)a3[5]*b29;
+                local_C2 += (float)a3[6]*b30; local_C3 += (float)a3[7]*b31;
+            } else {
+                float b_vals[32] = {b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15,
+                                    b16,b17,b18,b19,b20,b21,b22,b23,b24,b25,b26,b27,b28,b29,b30,b31};
 #pragma unroll
-            for (int k = 0; k < num_values_4bit / 4; k++) {
-#if BNB_BF16_AVAILABLE
-                local_C += (float)(local_A[k] * local_B[k]);
-#else
-                // bf16 multipliation not supported
-                local_C += ((float)local_A[k] * (float)local_B[k]);
-#endif
+                for (int k = 0; k < 32; k++) {
+                    float a_val = (inner_idx + k < K) ? (float)A_row[inner_idx + k] : 0.0f;
+                    local_C0 += a_val * b_vals[k];
+                }
             }
         }
-    }

-    local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
+        float local_C = local_C0 + local_C1 + local_C2 + local_C3;
+        local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);

-    if (row_B < M && warp_lane == 0)
-        out[row_B] = T(local_C);
+        if (warp_lane == 0)
+            out[n_idx * ldc + row_B] = T(local_C);
+    }
 }

 template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n) {
@@ -1595,6 +1634,18 @@ template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(
     int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
     float* out, int lda, int ldb, int ldc, int blocksize
 );
+template __global__ void kgemm_4bit_inference_naive<half, 64, 16>(
+    int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out,
+    int lda, int ldb, int ldc, int blocksize
+);
+template __global__ void kgemm_4bit_inference_naive<bnb_bfloat16, 64, 16>(
+    int M, int N, int K, bnb_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
+    bnb_bfloat16* out, int lda, int ldb, int ldc, int blocksize
+);
+template __global__ void kgemm_4bit_inference_naive<float, 64, 32>(
+    int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
+    float* out, int lda, int ldb, int ldc, int blocksize
+);

 template __global__ void kdequant_mm_int32_fp16<4, 512>(
     int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
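
The rewritten kernel keeps one warp per output row of B but now loops over all N activation rows, duplicates the 16-entry quantization map into 32 shared-memory floats so even and odd lanes read separate copies (qm_offset), and prefetches the next absmax block and packed B chunk while the current chunk is being consumed. The math per launch is unchanged; an eager PyTorch reference of what it computes, with illustrative shapes and names that are not part of the commit:

import torch

def gemm_dequant_reference(A, B_codes, absmax, code, blocksize):
    # A: (N, K) activations; B_codes: (M, K) integer 4-bit codes in [0, 16);
    # absmax: per-block scales over B's row-major elements; code: 16-entry map (e.g. NF4).
    M, K = B_codes.shape
    scales = absmax.repeat_interleave(blocksize).reshape(M, K)
    B = code[B_codes] * scales   # blockwise dequantized weight
    return A @ B.t()             # out[n, row] = sum_k A[n, k] * B[row, k]

A = torch.randn(4, 128)
B_codes = torch.randint(0, 16, (8, 128))
code = torch.randn(16)
absmax = torch.rand(8 * 128 // 64)
print(gemm_dequant_reference(A, B_codes, absmax, code, blocksize=64).shape)  # torch.Size([4, 8])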

csrc/ops.cu

Lines changed: 11 additions & 5 deletions
@@ -421,15 +421,21 @@ void gemm_4bit_inference_naive(
     int blocksize, bnb_stream_t stream
 ) {

-    int num_blocks = (m + 3) / 4;
 #if BNB_HIP
-    if (bnb_host_warp_size() == 64) {
-        num_blocks = (m + 1) / 2;
+    const int ws = bnb_host_warp_size();
+    int num_blocks = (m + 1) / 2;
+    if (ws == 32) {
+        kgemm_4bit_inference_naive<T, 64, BITS>
+            <<<num_blocks, 64, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
+    } else {
+        kgemm_4bit_inference_naive<T, 128, BITS>
+            <<<num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
     }
-#endif
-
+#else
+    int num_blocks = (m + 3) / 4;
     kgemm_4bit_inference_naive<T, 128, BITS>
         <<<num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
+#endif
     BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR());
 }
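
The launch now picks the block width from the host-visible warp size: on ROCm a block always holds two warps (two output rows), on CUDA it holds four. A small Python sketch of that grid arithmetic (illustrative only; one warp produces one output row):

def launch_config(m, is_hip, warp_size):
    if is_hip:
        threads = 64 if warp_size == 32 else 128  # 2 warps per block either way
        return (m + 1) // 2, threads              # ceil(m / 2) blocks
    return (m + 3) // 4, 128                      # CUDA: 4 warps of 32 -> ceil(m / 4) blocks

print(launch_config(11008, is_hip=True, warp_size=64))   # (5504, 128)
print(launch_config(11008, is_hip=False, warp_size=32))  # (2752, 128)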
