You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
51
+
kDequantize<<<num_blocks, 1024, 0, stream>>>(code, A, out, n);
52
52
CUDA_CHECK_RETURN(cudaPeekAtLastError());
53
53
}
54
54
@@ -76,16 +76,16 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
76
76
CUDA_CHECK_RETURN(cudaPeekAtLastError());
77
77
}
78
78
79
-
template<typename T, int DATA_TYPE> voiddequantizeBlockwise(float *code, unsignedchar *A, float *absmax, T *out, int blocksize, constint n)
79
+
template<typename T, int DATA_TYPE> voiddequantizeBlockwise(float *code, unsignedchar *A, float *absmax, T *out, int blocksize, constint n, cudaStream_t stream)
@@ -724,12 +724,11 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
724
724
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
725
725
}
726
726
727
-
template <typename T, int BITS> voidgemm_4bit_inference_naive(int m, int n, int k, T * A, unsignedchar* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
727
+
template <typename T, int BITS> voidgemm_4bit_inference_naive(int m, int n, int k, T * A, unsignedchar* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
728
728
{
729
729
730
730
int num_blocks = (m+3)/4;
731
-
732
-
kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
731
+
kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsignedchar* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
756
-
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsignedchar* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
757
-
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsignedchar* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
758
-
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsignedchar* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
755
+
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsignedchar* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
756
+
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsignedchar* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
757
+
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsignedchar* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
759
758
760
759
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
761
760
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
voidquantize(float *code, float *A, unsignedchar *out, int n);
145
-
voiddequantize(float *code, unsignedchar *A, float *out, int n);
146
+
voiddequantize(float *code, unsignedchar *A, float *out, int n, cudaStream_t stream);
146
147
template <typename T, int STOCHASTIC, int DATA_TYPE> voidquantizeBlockwise(float * code, T *A, float *absmax, unsignedchar *out, float* rand, int rand_offset, int blocksize, constint n);
147
-
template<typename T, int DATA_TYPE> voiddequantizeBlockwise(float *code, unsignedchar *A, float *absmax, T *out, int block_size, constint n);
148
+
template<typename T, int DATA_TYPE> voiddequantizeBlockwise(float *code, unsignedchar *A, float *absmax, T *out, int block_size, constint n, cudaStream_t stream);
148
149
149
150
template<typename T, int OPTIMIZER> voidoptimizer32bit(T* g, T* p, T* return_updates,
@@ -196,7 +197,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
196
197
197
198
template <typename T> voidgemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
198
199
template <typename T> voidgemm_4bit_inference(int m, int n, int k, T * A, unsignedchar* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
199
-
template <typename T, int BITS> voidgemm_4bit_inference_naive(int m, int n, int k, T * A, unsignedchar* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
200
+
template <typename T, int BITS> voidgemm_4bit_inference_naive(int m, int n, int k, T * A, unsignedchar* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
200
201
201
202
template <typename T, int FUNC> voidfunc(T *A, T *B, T value, long n);
0 commit comments