@@ -3829,27 +3829,33 @@ template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8
38293829
38303830MAKE_PreconditionOptimizer32bit1State (MOMENTUM, half)
38313831MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float )
3832+ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, __nv_bfloat16)
38323833MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
38333834MAKE_PreconditionOptimizer32bit1State(RMSPROP, float )
3835+ MAKE_PreconditionOptimizer32bit1State(RMSPROP, __nv_bfloat16)
38343836MAKE_PreconditionOptimizer32bit1State(LION, half)
38353837MAKE_PreconditionOptimizer32bit1State(LION, float )
38363838MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
38373839MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
38383840MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float )
3841+ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)
38393842
38403843#define MAKE_Optimizer32bit1State (oname, gtype ) \
38413844template __global__ void kOptimizer32bit1State <gtype, oname>(gtype* g, gtype* p, float * state1, float *unorm, const float max_unorm, const float param_norm, \
38423845 const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
38433846
38443847MAKE_Optimizer32bit1State (MOMENTUM, half)
38453848MAKE_Optimizer32bit1State(MOMENTUM, float )
3849+ MAKE_Optimizer32bit1State(MOMENTUM, __nv_bfloat16)
38463850MAKE_Optimizer32bit1State(RMSPROP, half)
38473851MAKE_Optimizer32bit1State(RMSPROP, float )
3852+ MAKE_Optimizer32bit1State(RMSPROP, __nv_bfloat16)
38483853MAKE_Optimizer32bit1State(LION, half)
38493854MAKE_Optimizer32bit1State(LION, float )
38503855MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
38513856MAKE_Optimizer32bit1State(ADAGRAD, half)
38523857MAKE_Optimizer32bit1State(ADAGRAD, float )
3858+ MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)
38533859
38543860#define MAKE_PreconditionOptimizer32bit2State (oname, gtype ) \
38553861template __global__ void kPreconditionOptimizer32bit2State <gtype, oname, 4096 , 8 >(gtype* g, gtype* p, \
@@ -3950,6 +3956,8 @@ MAKE_optimizerStatic8bit2State(ADAM, float)
39503956
39513957template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
39523958template __global__ void kPercentileClipping <half, 2048 , 4 >(half * __restrict__ g, float *gnorm_vec, int step, const int n);
3959+ // template __global__ void kPercentileClipping<float, 128, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
3960+ // template __global__ void kPercentileClipping<half, 128, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
39533961
39543962#define MAKE_kQuantizeBlockwise (dtype, blocksize, num_per_thread, stochastic, data_type_name ) \
39553963template __global__ void kQuantizeBlockwise <dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
@@ -4041,13 +4049,12 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block
40414049 float weight_decay, \
40424050 const float gnorm_scale, const bool skip_zeros, const int n); \
40434051
4044- MAKE_OptimizerStatic8bit2StateBlockwise (ADAM, float , 2048 , 8 )
4045- MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048 , 8 )
4046- MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048 , 8 )
4047- MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float , 2048 , 8 )
4048- MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 2048 , 8 )
4049- MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 2048 , 8 )
4050-
4052+ MAKE_OptimizerStatic8bit2StateBlockwise (ADAM, float , 256 , 1 )
4053+ MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256 , 1 )
4054+ MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 256 , 1 )
4055+ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float , 256 , 1 )
4056+ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256 , 1 )
4057+ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256 , 1 )
40514058
40524059#define MAKE_OptimizerStatic8bit1StateBlockwise (oname, gtype, block_size, num_per_thread ) \
40534060template __global__ void kOptimizerStatic8bit1StateBlockwise <gtype, oname, block_size, num_per_thread>( \
@@ -4059,15 +4066,18 @@ template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block
40594066 float weight_decay, \
40604067 const float gnorm_scale, const bool skip_zeros, const int n); \
40614068
4062- MAKE_OptimizerStatic8bit1StateBlockwise (MOMENTUM, float , 2048 , 8 )
4063- MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048 , 8 )
4064- MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float , 2048 , 8 )
4065- MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048 , 8 )
4066- MAKE_OptimizerStatic8bit1StateBlockwise(LION, float , 2048 , 8 )
4067- MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048 , 8 )
4068- MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048 , 8 )
4069- MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float , 2048 , 8 )
4070- MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048 , 8 )
4069+ MAKE_OptimizerStatic8bit1StateBlockwise (MOMENTUM, float , 256 , 1 )
4070+ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256 , 1 )
4071+ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, __nv_bfloat16, 256 , 1 )
4072+ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float , 256 , 1 )
4073+ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256 , 1 )
4074+ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, __nv_bfloat16, 256 , 1 )
4075+ MAKE_OptimizerStatic8bit1StateBlockwise(LION, float , 256 , 1 )
4076+ MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256 , 1 )
4077+ MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256 , 1 )
4078+ MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float , 256 , 1 )
4079+ MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256 , 1 )
4080+ MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256 , 1 )
40714081
40724082template __device__ void printnonzero<float>(float *A, int num_values, const char *strval);
40734083template __device__ void printnonzero<half>(half *A, int num_values, const char *strval);
0 commit comments