-//template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
-//{
-//// element-wise kernel
-//// 1. Load batch x k into registers
-//// 2. Load k x k into registers
-//// 3. dequantize and store in second pair of k x k
-//// 4. matmul
-//// 5. sum with cub
-//// 6. store outputs
-//// TC kernel
-//// use k warps per thread block
-//// 1. threadblock uses read-only cache to read in register tile for A into shared memory
-//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
-//// 3. each warp reads a segment of values 16x32 from B
-//// 4. do dequantization from register of B into second pair of registers
-//// 5. store (4) into fragment
-//// 6. matmul aggregate into fragment C
-//// 7. aggregate tiles of C into shared memory block C
-//// 8. sum (7)
-//// 9. write outputs to matmul output matrix
-//}
-
template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f)
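The default `zero_value` argument suggests `vector_load` is a bounds-guarded vectorized load: pull `ITEMS` elements of `buffer` into the per-thread `local` array through the wider `TCAST` type when the span is in bounds, and fall back to element-wise loads padded with `zero_value` otherwise. A minimal sketch of that idea follows; the body is an assumption for illustration, not the committed implementation.

template <typename T, typename TCAST, int ITEMS>
__device__ inline void vector_load_sketch(T *local, T * __restrict__ const buffer, int idx,
                                          int limit_base, int limit, float zero_value = 0.0f)
{
  if(limit_base + ITEMS <= limit)
  {
    // fast path: the whole ITEMS-wide span is in bounds; one wide load through
    // TCAST (assumes buffer + idx is suitably aligned for TCAST)
    *reinterpret_cast<TCAST*>(local) = *reinterpret_cast<const TCAST*>(&buffer[idx]);
  }
  else
  {
    // slow path: element-wise loads, padding out-of-bounds slots with zero_value
    #pragma unroll
    for(int k = 0; k < ITEMS; k++)
      local[k] = (limit_base + k < limit) ? buffer[idx + k] : (T)zero_value;
  }
}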
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{
@@ -3311,13 +3151,28 @@ template <typename T> __device__ void printnonzero(T *A, int num_values, const char *strval)
    printf("%s %i %f\n", strval, i, (float)A[i]);
}

-template __device__ void printnonzero<float>(float *A, int num_values, const char *strval);
-template __device__ void printnonzero<half>(half *A, int num_values, const char *strval);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
+//// element-wise kernel
+//// 1. Load batch x k into registers
+//// 2. Load k x k into registers
+//// 3. dequantize and store in second pair of k x k
+//// 4. matmul
+//// 5. sum with cub
+//// 6. store outputs
+//// TC kernel
+//// use k warps per thread block
+//// 1. threadblock uses read-only cache to read in register tile for A into shared memory
+//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
+//// 3. each warp reads a segment of values 16x32 from B
+//// 4. do dequantization from register of B into second pair of registers
+//// 5. store (4) into fragment
+//// 6. matmul aggregate into fragment C
+//// 7. aggregate tiles of C into shared memory block C
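The comment block re-added above is a plan, not an implementation. As a rough illustration of steps 2 through 6: the 8x16 A tiles and 16x32 B segments it describes line up with the WMMA m8n32k16 shape, so the plan maps naturally onto fragments. The sketch below is hypothetical (made-up names, sm_70+, launched with a single warp as <<<1, 32>>>), and it simplifies dequantization to one code per byte with a single absmax scale; a real 4-bit path would unpack two nibbles per byte and use per-block absmax values.

#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

__global__ void tc_plan_sketch(const half *A, const unsigned char *B,
                               const float *absmax, half *C)
{
  __shared__ half smem_A[8 * 16];      // step 1: A tile staged in shared memory
  __shared__ half deq_B[16 * 32];      // step 4: dequantized copy of the B segment

  // step 1: threadblock stages a tile of A into shared memory
  for(int i = threadIdx.x; i < 8 * 16; i += blockDim.x)
    smem_A[i] = A[i];

  // steps 3-4: dequantize the 16x32 segment of B (symmetric, scaled by absmax;
  // simplified to one code per byte, see the note above)
  for(int i = threadIdx.x; i < 16 * 32; i += blockDim.x)
    deq_B[i] = __float2half(((int)B[i] - 127) * (absmax[0] / 127.0f));
  __syncthreads();

  // steps 2 and 5: load shared-memory tiles into WMMA fragments; m8n32k16
  // matches the 8x16 A tiles and 16x32 B segments named in the plan
  wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
  wmma::fill_fragment(c_frag, __float2half(0.0f));
  wmma::load_matrix_sync(a_frag, smem_A, 16);
  wmma::load_matrix_sync(b_frag, deq_B, 32);

  // step 6: matmul-accumulate into fragment C
  wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

  // step 9: write the 8x32 result tile to the output matrix
  wmma::store_matrix_sync(C, c_frag, 32, wmma::mem_row_major);
}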
@@ -362,10 +355,6 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
}
}

-template int get_leading_dim<ROW>(int dim1, int dim2);
-template int get_leading_dim<COL>(int dim1, int dim2);
-template int get_leading_dim<COL32>(int dim1, int dim2);
-
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
#ifdef NO_CUBLASLT
@@ -411,15 +400,6 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
#endif
}

-template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
-
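The deleted lines above (like the printnonzero and get_leading_dim deletions earlier) are explicit template instantiations; dropping them only links cleanly if each needed specialization is still instantiated somewhere else, or the definitions move into a header. For reference, a toy example of the pattern with made-up names:

// toy example: explicit instantiation keeps a template's definition in a
// .cu/.cpp file while letting other translation units link against chosen
// specializations
template <typename T> void toy_transform(T *A, T *out, int n)
{
  for(int i = 0; i < n; i++)
    out[i] = A[i];
}

// without a line like this (or an instantiation elsewhere), callers in other
// files hit an unresolved-symbol error at link time
template void toy_transform<signed char>(signed char *A, signed char *out, int n);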
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_CUBLASLT
@@ -693,9 +673,9 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
//cout << m << endl;
//cout << n << endl;
//cout << k << endl;
-//if(bits == 32)
+if(bits == 32)
//gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
-//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
if(bits == 16)
//gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
gemm_device<T, 16, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
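One pattern worth noting in this hunk: the THREADS template parameter of gemm_device is kept in lock-step with the launch configuration (32 threads for the 32-bit path, 160 for the 16-bit path), which lets the kernel size shared memory and per-thread work at compile time. A self-contained sketch of that idiom, with hypothetical names:

#include <cuda_runtime.h>

// hypothetical kernel: the THREADS template parameter mirrors the runtime
// block size, so shared-memory tiles can be sized at compile time
template <typename T, int THREADS> __global__ void fixed_threads_kernel(T *out)
{
  __shared__ T tile[THREADS];
  tile[threadIdx.x] = (T)threadIdx.x;
  __syncthreads();
  out[threadIdx.x] = tile[THREADS - 1 - threadIdx.x];
}

int main()
{
  float *out;
  cudaMalloc(&out, 160 * sizeof(float));
  // the block size passed at launch must equal the template argument,
  // exactly as in the gemm_device launches above
  fixed_threads_kernel<float, 32><<<1, 32>>>(out);
  fixed_threads_kernel<float, 160><<<1, 160>>>(out);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}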