Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions csrc/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@
// Warp size

#if BNB_HIP
// CDNA (gfx9xx) = 64, RDNA = 32.
// CDNA (gfx9xx) = 64, RDNA (gfx10xx/gfx11xx/gfx12xx) = 32.
// __AMDGCN_WAVEFRONT_SIZE is not defined by all compiler versions (removed since ROCm 7.0),
// so fall back to architecture-family macros when it is absent.
// This is a macro that is defined by the compiler during each device-code pass and as such should only be used inside kernels.
#ifdef __AMDGCN_WAVEFRONT_SIZE
#define BNB_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
#elif defined(__GFX9__)
#define BNB_WARP_SIZE 64 // CDNA
#else
#define BNB_WARP_SIZE 64 // Safe default for HIP (matches CDNA)
#define BNB_WARP_SIZE 32 // RDNA and other
#endif
#else
#define BNB_WARP_SIZE 32
Expand Down
32 changes: 27 additions & 5 deletions csrc/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@

#define ERR_NOT_IMPLEMENTED 100

#if BNB_HIP
#include <hip/hip_runtime.h>
static int bnb_host_warp_size() {
Comment thread
sstamenk marked this conversation as resolved.
constexpr int MAX_DEVICES = 32;
static int cache[MAX_DEVICES] = {};
int dev;
(void)hipGetDevice(&dev);
if (dev < 0 || dev >= MAX_DEVICES) return 64;
if (cache[dev] == 0)
(void)hipDeviceGetAttribute(&cache[dev], hipDeviceAttributeWarpSize, dev);
return cache[dev];
}
#else
static constexpr int bnb_host_warp_size() { return 32; }
#endif


using std::cout;
using std::endl;

Expand All @@ -35,10 +52,16 @@ void quantizeBlockwise(
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if (blocksize == 64) {
#if BNB_HIP
// On HIP with 64-wide warps (CDNA), use specialized kernel for 4-bit types
if constexpr (DATA_TYPE > 0) {
kQuantizeBlockwiseSmall<T, DATA_TYPE>
<<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n);
if (bnb_host_warp_size() == 64) {
// CDNA: kQuantizeBlockwiseSmall is compiled with THREADS=64
kQuantizeBlockwiseSmall<T, DATA_TYPE>
<<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n);
} else {
// RDNA: standard kernel (same as CUDA path)
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>
<<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
}
} else {
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
}
Expand Down Expand Up @@ -407,8 +430,7 @@ void gemm_4bit_inference_naive(

int num_blocks = (m + 3) / 4;
#if BNB_HIP
// On 64-wide warp architectures, each warp processes 2 rows instead of 4
if (BNB_WARP_SIZE == 64) {
if (bnb_host_warp_size() == 64) {
num_blocks = (m + 1) / 2;
}
#endif
Expand Down
Loading