11#include < BinSearch.h>
2- #ifdef _WIN32
3- #include < thread>
4- #else
5- #include < pthread.h>
6- #endif
72#include < common.h>
3+ #include < thread>
84
95using namespace BinSearch ;
106
@@ -30,61 +26,38 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
3026 BinAlgo<Scalar, float , Direct2> bin_searcher (code, elements_code);
3127
3228 int thread_wave_size = 256 ;
33- // we chunk the thresds into waves of 256 since the max limit is
29+ // we chunk the threads into waves of 256 since the max limit is
3430 // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
3531 for (long long offset = 0 ; offset < num_blocks; offset+=thread_wave_size)
3632 {
3733 long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
38- #ifdef _WIN32
39- std::thread *threads = (std::thread *) malloc (sizeof (std::thread) * valid_chunks);
40- #else
41- pthread_t *threads = (pthread_t *) malloc (sizeof (pthread_t ) * valid_chunks);
42- #endif
43-
44- struct quantize_block_args **args = (quantize_block_args **) malloc (valid_chunks * sizeof (quantize_block_args *));
45-
46- for (long long i = 0 ; i < valid_chunks; i++)
47- args[i] = (quantize_block_args *) malloc (sizeof (quantize_block_args));
34+ std::vector<std::thread> threads (valid_chunks);
35+ std::vector<quantize_block_args> args (valid_chunks);
4836
4937 int chunks_processed = 0 ;
5038 for (long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
5139 {
5240 long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
5341 long long block_end = block_idx + valid_items;
5442
55- struct quantize_block_args *arg = args[chunks_processed];
56- arg->bin_searcher = &bin_searcher;
57- arg->code = code;
58- arg->A = A;
59- arg->absmax = absmax;
60- arg->out = out;
61- arg->block_end = block_end;
62- arg->block_idx = block_idx;
63- arg->threadidx = block_idx / blocksize;
64- arg->blocksize = blocksize;
65-
66- #ifdef _WIN32
67- new (&threads[chunks_processed]) std::thread (quantize_block, arg);
68- #else
69- pthread_create (&threads[chunks_processed], NULL , &quantize_block, (void *) arg);
70- #endif
43+ struct quantize_block_args & arg = args[chunks_processed];
44+ arg.bin_searcher = &bin_searcher;
45+ arg.code = code;
46+ arg.A = A;
47+ arg.absmax = absmax;
48+ arg.out = out;
49+ arg.block_end = block_end;
50+ arg.block_idx = block_idx;
51+ arg.threadidx = block_idx / blocksize;
52+ arg.blocksize = blocksize;
53+
54+ threads[chunks_processed] = std::thread ([arg] { quantize_block (arg); });
7155 chunks_processed += 1 ;
7256 if (chunks_processed == valid_chunks){ break ; }
7357 }
7458
7559 for (int i = 0 ; i < valid_chunks; i++)
76- {
77- #ifdef _WIN32
7860 threads[i].join ();
79- #else
80- int err = pthread_join (threads[i], NULL );
81- #endif
82- }
83- free (threads);
84- for (int i = 0 ; i < valid_chunks; i++)
85- free (args[i]);
86- free (args);
87-
8861 }
8962
9063}
0 commit comments