Commit 2fa43fe

jeejeelee authored and matthewdouglas committed
Enable certain CUDA kernels to accept specified cuda stream (bitsandbytes-foundation#1330)
* Done
* fix format
* fix format
* fix format
* fix format
* Address format error and fix default arg bug
* Refine stream argument passing mechanism
* Fix bug
* Delete unused code
1 parent fc32ff0 commit 2fa43fe

4 files changed

Lines changed: 77 additions & 58 deletions

bitsandbytes/functional.py

Lines changed: 22 additions & 3 deletions
@@ -439,6 +439,11 @@ def is_on_gpu(tensors):
     return on_gpu
 
 
+def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream:
+    stream = torch.cuda.current_stream(tensor.device)
+    return stream
+
+
 def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
     """
     Get the ctypes pointer from a PyTorch Tensor.
@@ -973,6 +978,7 @@ def dequantize_blockwise(
             f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
         )
     is_on_gpu([A, absmax, out])
+    stream = get_tensor_stream(A)
     if out.dtype == torch.float32:
         lib.cdequantize_blockwise_fp32(
             get_ptr(quant_state.code),
@@ -981,6 +987,7 @@ def dequantize_blockwise(
             get_ptr(out),
             ct.c_int(quant_state.blocksize),
             ct.c_int(A.numel()),
+            stream,  # Used the _as_parameter_ attribute of torch.cuda.Stream, Similarly for the following
         )
     elif out.dtype == torch.float16:
         lib.cdequantize_blockwise_fp16(
@@ -990,6 +997,7 @@ def dequantize_blockwise(
             get_ptr(out),
             ct.c_int(quant_state.blocksize),
             ct.c_int(A.numel()),
+            stream,
         )
     elif out.dtype == torch.bfloat16:
         lib.cdequantize_blockwise_bf16(
@@ -999,6 +1007,7 @@ def dequantize_blockwise(
             get_ptr(out),
             ct.c_int(quant_state.blocksize),
             ct.c_int(A.numel()),
+            stream,
         )
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
@@ -1176,7 +1185,6 @@ def quantize_4bit(
 
     prev_device = pre_call(A.device)
     is_on_gpu([A, out, absmax])
-
     if A.dtype == torch.float32:
         if quant_type == "fp4":
             lib.cquantize_blockwise_fp32_fp4(
@@ -1356,6 +1364,7 @@ def dequantize_4bit(
 
     device = pre_call(A.device)
     is_on_gpu([A, absmax, out])
+    stream = get_tensor_stream(A)
     if out.dtype == torch.float32:
         if quant_state.quant_type == "fp4":
             lib.cdequantize_blockwise_fp32_fp4(
@@ -1365,6 +1374,7 @@ def dequantize_4bit(
                 get_ptr(out),
                 ct.c_int(quant_state.blocksize),
                 ct.c_int(n),
+                stream,
             )
         else:
             lib.cdequantize_blockwise_fp32_nf4(
@@ -1374,6 +1384,7 @@ def dequantize_4bit(
                 get_ptr(out),
                 ct.c_int(quant_state.blocksize),
                 ct.c_int(n),
+                stream,
             )
     elif out.dtype == torch.float16:
         if quant_state.quant_type == "fp4":
@@ -1384,6 +1395,7 @@ def dequantize_4bit(
                 get_ptr(out),
                 ct.c_int(quant_state.blocksize),
                 ct.c_int(n),
+                stream,
             )
         else:
             lib.cdequantize_blockwise_fp16_nf4(
@@ -1393,6 +1405,7 @@ def dequantize_4bit(
                 get_ptr(out),
                 ct.c_int(quant_state.blocksize),
                 ct.c_int(n),
+                stream,
             )
     elif out.dtype == torch.bfloat16:
         if quant_state.quant_type == "fp4":
@@ -1403,6 +1416,7 @@ def dequantize_4bit(
                 get_ptr(out),
                 ct.c_int(quant_state.blocksize),
                 ct.c_int(n),
+                stream,
             )
         else:
             lib.cdequantize_blockwise_bf16_nf4(
@@ -1412,6 +1426,7 @@ def dequantize_4bit(
                 get_ptr(out),
                 ct.c_int(quant_state.blocksize),
                 ct.c_int(n),
+                stream,
             )
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
@@ -1518,7 +1533,8 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
     is_on_gpu([code, A, out])
-    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+    stream = get_tensor_stream(A)
+    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream)
     post_call(prev_device)
     return out
 
@@ -2013,7 +2029,7 @@ def gemv_4bit(
     lda = ct.c_int32(lda)
     ldb = ct.c_int32(ldb)
     ldc = ct.c_int32(ldc)
-
+    stream = get_tensor_stream(A)
     if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
         if A.dtype == torch.float16:
             lib.cgemm_4bit_inference_naive_fp16(
@@ -2029,6 +2045,7 @@ def gemv_4bit(
                 ldb,
                 ldc,
                 ct.c_int32(state.blocksize),
+                stream,
             )
         elif A.dtype == torch.bfloat16:
             lib.cgemm_4bit_inference_naive_bf16(
@@ -2044,6 +2061,7 @@ def gemv_4bit(
                 ldb,
                 ldc,
                 ct.c_int32(state.blocksize),
+                stream,
             )
         elif A.dtype == torch.float32:
             lib.cgemm_4bit_inference_naive_fp32(
@@ -2059,6 +2077,7 @@ def gemv_4bit(
                 ldb,
                 ldc,
                 ct.c_int32(state.blocksize),
+                stream,
             )
     else:
         raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
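
The interesting part of the Python change is how the stream crosses the ctypes boundary. The sketch below is illustrative, not part of the commit: it relies on torch.cuda.Stream defining _as_parameter_ (a ctypes.c_void_p wrapping the raw cudaStream_t handle), which is exactly what the code comment in the hunk above refers to.

import ctypes
import torch

# Mirror of the helper added in this commit: return the stream that is
# current on the tensor's device.
def get_tensor_stream(tensor: torch.Tensor) -> torch.cuda.Stream:
    return torch.cuda.current_stream(tensor.device)

if torch.cuda.is_available():
    A = torch.empty(16, device="cuda")
    stream = get_tensor_stream(A)
    # ctypes converts any object exposing _as_parameter_; torch.cuda.Stream
    # exposes one as a c_void_p over the raw handle, so the Stream object
    # itself can be passed as the trailing cudaStream_t argument of a lib.* call.
    print(type(stream._as_parameter_))  # <class 'ctypes.c_void_p'>
    print(stream.cuda_stream)           # raw handle the C side receives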

csrc/ops.cu

Lines changed: 20 additions & 21 deletions
@@ -44,11 +44,11 @@ void quantize(float *code, float *A, unsigned char *out, int n)
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
-void dequantize(float *code, unsigned char *A, float *out, int n)
+void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream)
 {
   int num_blocks = n/1024;
   num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
-  kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
+  kDequantize<<<num_blocks, 1024, 0, stream>>>(code, A, out, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
@@ -76,16 +76,16 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
-template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
+template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, cudaStream_t stream)
 {
+  // printf("stream==%d\n",stream);
   int num_blocks = n/blocksize;
   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
   int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
-
   if(DATA_TYPE > 0)
-    kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
+    kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize/2, n);
   else
-    kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);
+    kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);
 
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
@@ -724,12 +724,11 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
 	//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }
 
-template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
 {
 
   int num_blocks = (m+3)/4;
-
-  kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
@@ -753,9 +752,9 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
 template void func<float, _MUL>(float *A, float *B, float value, long n);
 
 template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
+template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
+template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
 
 //template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
@@ -795,15 +794,15 @@ template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __n
 template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
 template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
 
-template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
-template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
-template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
-template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
-template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
-template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
-template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
-template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
-template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
+template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
+template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
+template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
+template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
+template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
+template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
+template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
+template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
+template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
 
 #define MAKE_optimizer32bit(name, gtype) \
 template void optimizer32bit<gtype, name>(gtype* g, gtype* p, gtype* return_updates, \
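
Why plumb the stream through at all: before this change the kernels above always launched on the legacy default stream (the trailing 0 in <<<..., 0, 0>>>), so work issued while a non-default stream was current could be misordered relative to surrounding PyTorch ops. A hedged usage sketch, not from the commit (sizes and the error check are illustrative), showing that the kernels now follow the caller's current stream:

import torch
import bitsandbytes.functional as F

if torch.cuda.is_available():
    A = torch.randn(4096, device="cuda")
    q, state = F.quantize_blockwise(A)

    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())  # order s after the quantize work
    with torch.cuda.stream(s):
        # get_tensor_stream(q) now resolves to s, so kDequantizeBlockwise
        # launches on s instead of the legacy default stream.
        out = F.dequantize_blockwise(q, state)
    s.synchronize()
    print((A - out).abs().max())  # small 8-bit blockwise quantization error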

csrc/ops.cuh

Lines changed: 4 additions & 3 deletions
@@ -7,6 +7,7 @@
 #ifndef ops_H
 #define ops_H
 
+#include <cstdint>
 #include <stdio.h>
 #include <iostream>
 #include <assert.h>
@@ -142,9 +143,9 @@ class ContextCusparse
 template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
 
 void quantize(float *code, float *A, unsigned char *out, int n);
-void dequantize(float *code, unsigned char *A, float *out, int n);
+void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream);
 template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
+template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, cudaStream_t stream);
 
 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, T* return_updates,
 	float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
@@ -196,7 +197,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
 
 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
 
 template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
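
These prototype changes also matter for anyone calling the shared library directly through ctypes: each updated entry point now takes a trailing cudaStream_t. A hedged sketch mirroring the cdequantize call from the functional.py hunk above (the bitsandbytes.cextension import path is an assumption based on this version of the codebase):

import ctypes as ct
import torch
import bitsandbytes.functional as F
from bitsandbytes.cextension import lib  # the handle bitsandbytes loads internally

if torch.cuda.is_available():
    code = F.create_dynamic_map().to("cuda")  # 256-entry float32 code table
    A = torch.randint(0, 256, (1024,), dtype=torch.uint8, device="cuda")
    out = torch.empty(A.numel(), dtype=torch.float32, device="cuda")
    # New prototype: void cdequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream)
    lib.cdequantize(F.get_ptr(code), F.get_ptr(A), F.get_ptr(out),
                    ct.c_int(A.numel()), F.get_tensor_stream(A))
    torch.cuda.synchronize()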
