Skip to content

Commit 8af17eb

Browse files
Jamezo97 authored and wkpark committed
Minimal fix for Windows compilation issues

Manually cherry-picked from PR #788, with cleanup.

Signed-off-by: Won-Kyu Park <wkpark@gmail.com>
1 parent 726f147 commit 8af17eb

4 files changed

Lines changed: 43 additions & 10 deletions

File tree

csrc/cpu_ops.cpp

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,9 @@
11
#include <BinSearch.h>
2+
#ifdef _WIN32
3+
#include <thread>
4+
#else
25
#include <pthread.h>
6+
#endif
37
#include <common.h>
48

59
using namespace BinSearch;
@@ -31,7 +35,11 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
3135
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
3236
{
3337
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
38+
#ifdef _WIN32
39+
std::thread *threads = (std::thread *) malloc(sizeof(std::thread) * valid_chunks);
40+
#else
3441
pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
42+
#endif
3543

3644
struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
3745

@@ -55,14 +63,23 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
5563
arg->threadidx = block_idx / blocksize;
5664
arg->blocksize = blocksize;
5765

66+
#ifdef _WIN32
67+
new (&threads[chunks_processed]) std::thread(quantize_block, arg);
68+
#else
5869
pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
70+
#endif
5971
chunks_processed += 1;
6072
if(chunks_processed == valid_chunks){ break; }
6173
}
6274

6375
for (int i = 0; i < valid_chunks; i++)
76+
{
77+
#ifdef _WIN32
78+
threads[i].join();
79+
#else
6480
int err = pthread_join(threads[i], NULL);
65-
81+
#endif
82+
}
6683
free(threads);
6784
for (int i = 0; i < valid_chunks; i++)
6885
free(args[i]);

csrc/kernels.cu

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3816,12 +3816,12 @@ template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N
38163816
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
38173817
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
38183818

3819-
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3820-
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3821-
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3822-
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3823-
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3824-
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3819+
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3820+
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3821+
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3822+
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3823+
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3824+
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
38253825

38263826
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
38273827
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);

csrc/ops.cuh

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
#include <stdio.h>
1111
#include <iostream>
12-
#include <unistd.h>
1312
#include <assert.h>
1413

1514
#include <cuda_runtime_api.h>

include/Type.h

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -201,13 +201,30 @@ struct CondData<T,false>
201201
FORCE_INLINE operator const T() const { return 0;}
202202
};
203203

204+
#ifdef _WIN32
205+
// The `IsComplete` buildtime check doesn't work on Windows
206+
// Given the usage of the BinAlgoBase class, `I != Scalar` should be equivalent to the unix
207+
// equivalent below of `Details::IsComplete<Details::AlgoVecBase<I, T, A>>::value`
208+
template <InstrSet I, typename T, Algos A>
209+
struct WouldAlgoVecBaseBeComplete
210+
{
211+
static constexpr bool value{I != Scalar};
212+
};
213+
#else
214+
template <InstrSet I, typename T, Algos A>
215+
struct WouldAlgoVecBaseBeComplete : public Details::IsComplete<Details::AlgoVecBase<I, T, A>>
216+
{
217+
218+
};
219+
#endif
220+
204221
template <InstrSet I, typename T, Algos A, bool L=false>
205-
struct BinAlgoBase : Details::conditional< Details::IsComplete<Details::AlgoVecBase<I, T, A>>::value
222+
struct BinAlgoBase : Details::conditional< WouldAlgoVecBaseBeComplete<I, T, A>::value
206223
, Details::AlgoVecBase<I, T, A>
207224
, Details::AlgoScalarToVec<T,A>
208225
>::type
209226
{
210-
typedef typename Details::conditional< Details::IsComplete<Details::AlgoVecBase<I, T, A>>::value
227+
typedef typename Details::conditional< WouldAlgoVecBaseBeComplete<I, T, A>::value
211228
, Details::AlgoVecBase<I, T, A>
212229
, Details::AlgoScalarToVec<T,A>
213230
>::type base_t;

0 commit comments

Comments
 (0)