Skip to content

Commit 6bef412

Browse files
authored
CUDA source cleanup, refactor and fixes (#1328)
* remove kcompress
* fix initial template call
* fix function name
* remove vector load
* cleanup reduce & rearrange
* format
1 parent 432a4f4 commit 6bef412

3 files changed

Lines changed: 50 additions & 193 deletions

File tree

csrc/kernels.cu

Lines changed: 32 additions & 169 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
#define NUM 4
2121
#define NUM_BLOCK 4096
2222

23+
__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
2324

2425
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
2526
__device__ float atomicMax(float* address, float val) {
@@ -462,50 +463,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran
462463
}
463464
}
464465

465-
template <int SIGNED>
466-
__device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper)
467-
{
468-
int lower_pivot = QUADRANT*16-1 - 0;
469-
int pivot = QUADRANT*16-1 + 16;
470-
int upper_pivot = QUADRANT*16-1 + 31;
471-
472-
float val = midpoint;
473-
474-
// i>>=1 = {32, 16, 8, 4, 2, 1}
475-
for(int i = 16; i > 0; i>>=1)
476-
{
477-
if(x > val)
478-
{
479-
lower_pivot = pivot;
480-
lower = val;
481-
pivot+=i;
482-
}
483-
else
484-
{
485-
upper_pivot = pivot;
486-
upper = val;
487-
pivot-=i;
488-
}
489-
val = smem_code[pivot];
490-
}
491-
492-
if(x > val)
493-
{
494-
midpoint = (upper+val)*0.5f;
495-
if(x > midpoint)
496-
return upper_pivot;
497-
else
498-
return pivot;
499-
}
500-
else
501-
{
502-
midpoint = (lower+val)*0.5f;
503-
if(x < midpoint)
504-
return lower_pivot;
505-
else
506-
return pivot;
507-
}
508-
}
509466

510467
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
511468
{
@@ -519,86 +476,6 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index
519476
}
520477
}
521478

522-
template<typename T, int BLOCK_SIZE, int NUM_MAX>
523-
__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n)
524-
{
525-
typedef cub::WarpReduce<T> WarpReduce;
526-
__shared__ typename WarpReduce::TempStorage temp_storage;
527-
typedef cub::BlockLoad<T, BLOCK_SIZE/8 , 8, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
528-
__shared__ typename LoadT::TempStorage loadt;
529-
530-
const int warp_idx = threadIdx.x/32;
531-
const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE);
532-
533-
// BLOCK_SIZE/32 == number of warps
534-
__shared__ int smem_max_indices[8*BLOCK_SIZE/32];
535-
__shared__ float smem_max_values[8*BLOCK_SIZE/32];
536-
537-
T values[8];
538-
T max1 = -64000.0f;
539-
T max2 = -64000.0f;
540-
int max_idx1 = -1;
541-
int max_idx2 = -1;
542-
int sign1 = -1;
543-
int sign2 = -1;
544-
545-
// 1. load 8 values per thread
546-
// 2. compute 2-max in registers (64 max per warp)
547-
// 3. do warp reduction + broadcast back
548-
// 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
549-
// 5. Repeat (3) 8 times for top 8 values in 256
550-
// 6. store with byte index
551-
552-
LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f);
553-
#pragma unroll 8
554-
for(int i = 0; i < 8; i++)
555-
{
556-
T absval = fabsf(values[i]);
557-
if(absval > max1)
558-
{
559-
max1 = values[i];
560-
sign1 = signbit(values[i]);
561-
max_idx1 = 8*threadIdx.x + i;
562-
}
563-
else if(absval > max2)
564-
{
565-
max2 = values[i];
566-
sign2 = signbit(values[i]);
567-
max_idx2 = 8*threadIdx.x + i;
568-
}
569-
}
570-
571-
float warp_max;
572-
for(int i = 0; i < 8; i++)
573-
{
574-
// 3. do warp reduction + broadcast back
575-
warp_max = WarpReduce(temp_storage).Reduce(max1, cub::Max());
576-
warp_max = cub::ShuffleIndex<32>(warp_max, 0, 0xffffffff);
577-
578-
// 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
579-
if(warp_max == max1)
580-
{
581-
smem_max_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1;
582-
smem_max_indices[warp_idx*8 + i] = max_idx1;
583-
584-
sign1 = sign2;
585-
max1 = max2;
586-
max_idx1 = max_idx2;
587-
588-
max2 = -64000.0f;
589-
}
590-
__syncwarp();
591-
}
592-
593-
if(threadIdx.x % 32 < 8)
594-
{
595-
// offset: 8 values per 256 input values
596-
//
597-
int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8;
598-
}
599-
600-
}
601-
602479
#define THREADS_ESTIMATE 512
603480
#define NUM_ESTIMATE 8
604481
#define BLOCK_ESTIMATE 4096
@@ -1560,7 +1437,8 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
15601437
s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0];
15611438
switch(OPTIMIZER)
15621439
{
1563-
case MOMENTUM:
1440+
case ADAGRAD:
1441+
case MOMENTUM:
15641442
if(step == 1)
15651443
s1_vals[j] = (float)g_vals[j];
15661444
else
@@ -1663,6 +1541,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
16631541

16641542
if(weight_decay > 0.0f) {
16651543
switch(OPTIMIZER) {
1544+
case ADAGRAD:
16661545
case MOMENTUM:
16671546
case RMSPROP:
16681547
g_val += ((float)p_vals[j])*weight_decay;
@@ -1675,8 +1554,8 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
16751554

16761555
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
16771556

1678-
switch(OPTIMIZER)
1679-
{
1557+
switch(OPTIMIZER){
1558+
case ADAGRAD:
16801559
case MOMENTUM:
16811560
if(step == 1)
16821561
s1_vals[j] = g_vals[j];
@@ -3055,45 +2934,6 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
30552934
}
30562935
}
30572936

3058-
3059-
//template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
3060-
//{
3061-
//// element-wise kernel
3062-
//// 1. Load batch x k into registers
3063-
//// 2. Load k x k into registers
3064-
//// 3. dequantize and store in second pair of k x k
3065-
//// 4. matmul
3066-
//// 5. sum with cub
3067-
//// 6. store outputs
3068-
//// TC kernel
3069-
//// use k warps per thread block
3070-
//// 1. threadblock use read-only cache to read in register tile for A into shared memory
3071-
//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
3072-
//// 3. each warp reads a segment of values 16x32 from B
3073-
//// 4. do dequantization from register of B into second pair of registers
3074-
//// 5. store (4) into fragment
3075-
//// 6. matmul aggregate into fragment C
3076-
//// 7. aggregate files of C into shared memory block C
3077-
//// 8. sum (7)
3078-
//// 9. write outputs to matmul output matrix
3079-
//}
3080-
3081-
template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f)
3082-
{
3083-
if(limit_base + ITEMS <= limit)
3084-
reinterpret_cast<TCAST*>(local)[0] = reinterpret_cast<TCAST*>(buffer)[idx/ITEMS];
3085-
else
3086-
{
3087-
for(int k = 0; k < ITEMS; k++)
3088-
{
3089-
if(limit_base + k < limit)
3090-
local[k] = buffer[idx+k];
3091-
else
3092-
local[k] = (T)zero_value;
3093-
}
3094-
}
3095-
}
3096-
30972937
#define WARPS 3
30982938
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
30992939
{
@@ -3311,13 +3151,28 @@ template <typename T> __device__ void printnonzero(T *A, int num_values, const c
33113151
printf("%s %i %f\n", strval, i, (float)A[i]);
33123152
}
33133153

3314-
template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
3315-
template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);
33163154

3317-
__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
33183155
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
33193156
{
33203157

3158+
//// element-wise kernel
3159+
//// 1. Load batch x k into registers
3160+
//// 2. Load k x k into registers
3161+
//// 3. dequantize and store in second pair of k x k
3162+
//// 4. matmul
3163+
//// 5. sum with cub
3164+
//// 6. store outputs
3165+
//// TC kernel
3166+
//// use k warps per thread block
3167+
//// 1. threadblock use read-only cache to read in register tile for A into shared memory
3168+
//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
3169+
//// 3. each warp reads a segment of values 16x32 from B
3170+
//// 4. do dequantization from register of B into second pair of registers
3171+
//// 5. store (4) into fragment
3172+
//// 6. matmul aggregate into fragment C
3173+
//// 7. aggregate files of C into shared memory block C
3174+
//// 8. sum (7)
3175+
//// 9. write outputs to matmul output matrix
33213176
#if __CUDA_ARCH__ >= 750
33223177
using namespace nvcuda;
33233178
int col_offset = blockIdx.x *32;
@@ -3911,6 +3766,8 @@ MAKE_PreconditionStatic8bit1State(RMSPROP, half)
39113766
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
39123767
MAKE_PreconditionStatic8bit1State(LION, half)
39133768
MAKE_PreconditionStatic8bit1State(LION, float)
3769+
MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
3770+
MAKE_PreconditionStatic8bit1State(ADAGRAD, float)
39143771

39153772
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
39163773
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
@@ -3930,6 +3787,9 @@ MAKE_optimizerStatic8bit1State(RMSPROP, half)
39303787
MAKE_optimizerStatic8bit1State(RMSPROP, float)
39313788
MAKE_optimizerStatic8bit1State(LION, half)
39323789
MAKE_optimizerStatic8bit1State(LION, float)
3790+
MAKE_optimizerStatic8bit1State(ADAGRAD, half)
3791+
MAKE_optimizerStatic8bit1State(ADAGRAD, float)
3792+
39333793

39343794
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
39353795
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
@@ -4075,3 +3935,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
40753935
MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
40763936
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
40773937
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
3938+
3939+
template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
3940+
template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);

csrc/kernels.cuh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#ifndef kernels
1010
#define kernels
1111

12-
//template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);
1312

1413
template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
1514

csrc/ops.cu

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,6 @@ template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsign
9191
}
9292

9393

94-
//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
95-
//{
96-
// int num_blocks = (colsB+32-1)/32;
97-
// kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
98-
// CUDA_CHECK_RETURN(cudaPeekAtLastError());
99-
//}
100-
10194

10295
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
10396
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
@@ -362,10 +355,6 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
362355
}
363356
}
364357

365-
template int get_leading_dim<ROW>(int dim1, int dim2);
366-
template int get_leading_dim<COL>(int dim1, int dim2);
367-
template int get_leading_dim<COL32>(int dim1, int dim2);
368-
369358
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
370359
{
371360
#ifdef NO_CUBLASLT
@@ -411,15 +400,6 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
411400
#endif
412401
}
413402

414-
template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
415-
template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
416-
template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
417-
template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
418-
template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
419-
template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
420-
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
421-
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
422-
423403
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
424404
{
425405
#ifdef NO_CUBLASLT
@@ -693,9 +673,9 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
693673
//cout << m << endl;
694674
//cout << n << endl;
695675
//cout << k << endl;
696-
//if(bits == 32)
676+
if(bits == 32)
697677
//gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
698-
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
678+
gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
699679
if(bits == 16)
700680
//gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
701681
gemm_device<T, 16, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
@@ -841,6 +821,9 @@ MAKE_optimizerStatic8bit(RMSPROP, half)
841821
MAKE_optimizerStatic8bit(RMSPROP, float)
842822
MAKE_optimizerStatic8bit(LION, half)
843823
MAKE_optimizerStatic8bit(LION, float)
824+
MAKE_optimizerStatic8bit(ADAGRAD, half)
825+
MAKE_optimizerStatic8bit(ADAGRAD, float)
826+
844827

845828
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
846829
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
@@ -849,6 +832,7 @@ template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g
849832

850833
MAKE_optimizerStatic8bitBlockwise(half, ADAM);
851834
MAKE_optimizerStatic8bitBlockwise(float, ADAM);
835+
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
852836
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
853837
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
854838
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
@@ -862,4 +846,15 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
862846
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
863847
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
864848

865-
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
849+
template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
850+
template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
851+
template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
852+
template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
853+
template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
854+
template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
855+
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
856+
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
857+
858+
template int get_leading_dim<ROW>(int dim1, int dim2);
859+
template int get_leading_dim<COL>(int dim1, int dim2);
860+
template int get_leading_dim<COL32>(int dim1, int dim2);

0 commit comments

Comments
 (0)