
Commit aa57bd8

Change 8bit optimizer blocksize 2048->256; additional bf16 support (#1365)
* Change 8bit optimizer blocksize 2048->256; additional bf16 support
* Update tolerances for 8bit optimizer tests
1 parent d964546 commit aa57bd8

7 files changed

Lines changed: 88 additions & 52 deletions

File tree

bitsandbytes/functional.py
bitsandbytes/optim/ademamix.py
bitsandbytes/optim/optimizer.py
csrc/kernels.cu
csrc/ops.cu
csrc/pythonInterface.cpp
tests/test_optim.py

bitsandbytes/functional.py

Lines changed: 5 additions & 1 deletion
@@ -52,6 +52,7 @@ def prod(iterable):
     "lamb": (
         lib.cadam32bit_grad_fp32,
         lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
     ),
     "ademamix": (
         lib.cademamix32bit_grad_fp32,
@@ -96,10 +97,12 @@ def prod(iterable):
     "momentum": (
         lib.cmomentum_8bit_blockwise_grad_fp32,
         lib.cmomentum_8bit_blockwise_grad_fp16,
+        lib.cmomentum_8bit_blockwise_grad_bf16,
     ),
     "rmsprop": (
         lib.crmsprop_8bit_blockwise_grad_fp32,
         lib.crmsprop_8bit_blockwise_grad_fp16,
+        lib.crmsprop_8bit_blockwise_grad_bf16,
     ),
     "lion": (
         lib.clion_8bit_blockwise_grad_fp32,
@@ -109,6 +112,7 @@ def prod(iterable):
     "adagrad": (
         lib.cadagrad_8bit_blockwise_grad_fp32,
         lib.cadagrad_8bit_blockwise_grad_fp16,
+        lib.cadagrad_8bit_blockwise_grad_bf16,
     ),
     "ademamix": (
         lib.cademamix_8bit_blockwise_grad_fp32,
@@ -398,7 +402,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     data.append(0)

     data.sort()
-    return Tensor(data)
+    return torch.tensor(data)


 def create_quantile_map(A, total_bits=8):
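For reference, each tuple above is ordered (fp32, fp16, bf16), so dtype dispatch reduces to an index lookup. A minimal sketch of that selection logic (select_kernel is a hypothetical helper for illustration; the real dispatch lives elsewhere in bitsandbytes.functional):

    import torch

    # Hypothetical illustration: pick the C kernel matching the gradient dtype
    # from a (fp32, fp16, bf16) tuple like the ones extended in this diff.
    def select_kernel(kernels, dtype):
        index = {torch.float32: 0, torch.float16: 1, torch.bfloat16: 2}[dtype]
        if index >= len(kernels):
            raise ValueError(f"no kernel registered for {dtype}")
        return kernels[index]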

bitsandbytes/optim/ademamix.py

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ def init_state(self, group, p, gindex, pindex):
         self.name2qmap["udynamic"] = state["qmap2"] = self.name2qmap["udynamic"].to(p.device)

         n = p.numel()
-        blocks = (n // 2048) + bool(n % 2048)
+        blocks = (n // 256) + bool(n % 256)

         state["absmax1"] = torch.zeros((2, blocks), dtype=torch.float32, device=p.device)
         state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)

bitsandbytes/optim/optimizer.py

Lines changed: 4 additions & 4 deletions
@@ -477,8 +477,8 @@ def init_state(self, group, p, gindex, pindex):

         if config["block_wise"]:
             n = p.numel()
-            blocks = n // 2048
-            blocks += 1 if n % 2048 > 0 else 0
+            blocks = n // 256
+            blocks += 1 if n % 256 > 0 else 0

             state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
             state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
@@ -699,8 +699,8 @@ def init_state(self, group, p, gindex, pindex):

         if config["block_wise"]:
             n = p.numel()
-            blocks = n // 2048
-            blocks += 1 if n % 2048 > 0 else 0
+            blocks = n // 256
+            blocks += 1 if n % 256 > 0 else 0

             state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
         else:
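Dropping the blocksize from 2048 to 256 multiplies the number of float32 absmax scales by 8, but the extra memory stays negligible next to the 1 byte/param of quantized state; a rough back-of-the-envelope sketch:

    # Per-parameter overhead of the float32 absmax scales, in bytes/param.
    def absmax_bytes_per_param(n, blocksize):
        blocks = (n + blocksize - 1) // blocksize
        return blocks * 4 / n

    # blocksize 2048: ~0.002 bytes/param; blocksize 256: ~0.016 bytes/param,
    # still tiny compared to the 1 byte/param of 8-bit optimizer state.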

csrc/kernels.cu

Lines changed: 26 additions & 16 deletions
@@ -3829,27 +3829,33 @@ template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>

 MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
 MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
+MAKE_PreconditionOptimizer32bit1State(MOMENTUM, __nv_bfloat16)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
+MAKE_PreconditionOptimizer32bit1State(RMSPROP, __nv_bfloat16)
 MAKE_PreconditionOptimizer32bit1State(LION, half)
 MAKE_PreconditionOptimizer32bit1State(LION, float)
 MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
+MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)

 #define MAKE_Optimizer32bit1State(oname, gtype) \
 template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
     const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \

 MAKE_Optimizer32bit1State(MOMENTUM, half)
 MAKE_Optimizer32bit1State(MOMENTUM, float)
+MAKE_Optimizer32bit1State(MOMENTUM, __nv_bfloat16)
 MAKE_Optimizer32bit1State(RMSPROP, half)
 MAKE_Optimizer32bit1State(RMSPROP, float)
+MAKE_Optimizer32bit1State(RMSPROP, __nv_bfloat16)
 MAKE_Optimizer32bit1State(LION, half)
 MAKE_Optimizer32bit1State(LION, float)
 MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
 MAKE_Optimizer32bit1State(ADAGRAD, half)
 MAKE_Optimizer32bit1State(ADAGRAD, float)
+MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)

 #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
 template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
@@ -3950,6 +3956,8 @@ MAKE_optimizerStatic8bit2State(ADAM, float)

 template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
 template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
+// template __global__ void kPercentileClipping<float, 128, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
+// template __global__ void kPercentileClipping<half, 128, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);

 #define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
 template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
@@ -4041,13 +4049,12 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block
     float weight_decay, \
     const float gnorm_scale, const bool skip_zeros, const int n); \

-MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 2048, 8)
-
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 256, 1)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)

 #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
 template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
@@ -4059,15 +4066,18 @@ template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block
     float weight_decay, \
     const float gnorm_scale, const bool skip_zeros, const int n); \

-MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, __nv_bfloat16, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, __nv_bfloat16, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1)

 template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
 template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);
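Swapping the template parameters from (2048, 8) to (256, 1) keeps the CUDA thread count per block at 256 while shrinking the quantization block to 256 values, so each thread now handles exactly one element. A quick sanity check of that arithmetic (an illustration, not the kernel code):

    # threads per CUDA block = block_size / num_per_thread
    for block_size, num_per_thread in [(2048, 8), (256, 1)]:
        assert block_size % num_per_thread == 0
        print(block_size, num_per_thread, block_size // num_per_thread)  # 256 threads both times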

csrc/ops.cu

Lines changed: 10 additions & 4 deletions
@@ -191,10 +191,10 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
     }
 }

-#define BLOCKSIZE_2STATE 2048
-#define NUM_2STATE 8
-#define BLOCKSIZE_1STATE 2048
-#define NUM_1STATE 8
+#define BLOCKSIZE_2STATE 256
+#define NUM_2STATE 1
+#define BLOCKSIZE_1STATE 256
+#define NUM_1STATE 1

 template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
     T* p,
@@ -818,13 +818,16 @@ MAKE_optimizer32bit(ADAM, float)
 MAKE_optimizer32bit(ADAM, __nv_bfloat16)
 MAKE_optimizer32bit(MOMENTUM, half)
 MAKE_optimizer32bit(MOMENTUM, float)
+MAKE_optimizer32bit(MOMENTUM, __nv_bfloat16)
 MAKE_optimizer32bit(RMSPROP, half)
 MAKE_optimizer32bit(RMSPROP, float)
+MAKE_optimizer32bit(RMSPROP, __nv_bfloat16)
 MAKE_optimizer32bit(LION, half)
 MAKE_optimizer32bit(LION, float)
 MAKE_optimizer32bit(LION, __nv_bfloat16)
 MAKE_optimizer32bit(ADAGRAD, half)
 MAKE_optimizer32bit(ADAGRAD, float)
+MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16)
 MAKE_optimizer32bit(ADEMAMIX, half)
 MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
 MAKE_optimizer32bit(ADEMAMIX, float)
@@ -861,13 +864,16 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAM);
 MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
 MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(half, LION);
 MAKE_optimizerStatic8bitBlockwise(float, LION);
 MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
 MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX);
 MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX);
 MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);

csrc/pythonInterface.cpp

Lines changed: 10 additions & 4 deletions
@@ -103,19 +103,22 @@ void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
 { optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\

 MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
+MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
 MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16)
+MAKE_BLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
+MAKE_BLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
 MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
-MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(lion, LION, half, fp16)
-MAKE_BLOCKWISE8(lion, LION, float, fp32)
 MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
-MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
+MAKE_BLOCKWISE8(lion, LION, float, fp32)
 MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
+MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32)


@@ -283,13 +286,16 @@ extern "C"

 MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
 MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
+MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
 MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
+MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
 MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
 MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
+MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
 MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
 MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
-MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
+MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
 MAKE_CBLOCKWISE8(lion, LION, half, fp16)
 MAKE_CBLOCKWISE8(lion, LION, float, fp32)
 MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
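With these exports in place, bf16 parameters should work from Python with the additional 8-bit blockwise optimizers, not just Adam and Lion. A minimal usage sketch (assuming a CUDA build of the library that includes these kernels):

    import torch
    import bitsandbytes as bnb

    # bf16 parameter trained with blockwise 8-bit Adagrad, one of the
    # optimizer/dtype combinations newly wired through in this commit.
    p = torch.nn.Parameter(torch.randn(64, 64, device="cuda", dtype=torch.bfloat16))
    opt = bnb.optim.Adagrad8bit([p], lr=1e-3)
    p.grad = 0.01 * torch.randn_like(p)
    opt.step()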

tests/test_optim.py

Lines changed: 32 additions & 22 deletions
@@ -74,10 +74,18 @@ def rm_path(path):
     lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
     lambda pxx: bnb.optim.AdEMAMix(pxx, t_alpha=k, t_beta3=k),
 )
+str2optimizers["paged_ademamix_scheduled"] = (
+    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
+    lambda pxx: bnb.optim.PagedAdEMAMix(pxx, t_alpha=k, t_beta3=k),
+)
 str2optimizers["ademamix8bit_blockwise_scheduled"] = (
     lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
     lambda pxx: bnb.optim.AdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
 )
+str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
+    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
+    lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
+)

 str2optimizers["lion"] = (Lion, bnb.optim.Lion)
 str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
@@ -143,7 +151,7 @@ def rm_path(path):
 str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

 str2statenames["ademamix"] = str2statenames["ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
-str2statenames["paged_ademamix"] = [("m1_m2", "state1"), ("nu", "state2")]
+str2statenames["paged_ademamix"] = str2statenames["paged_ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
 str2statenames["ademamix8bit_blockwise"] = str2statenames["ademamix8bit_blockwise_scheduled"] = [
     ("m1_m2", "state1", "qmap1", "absmax1"),
     ("nu", "state2", "qmap2", "absmax2"),
@@ -164,6 +172,7 @@ def rm_path(path):
     "ademamix",
     "ademamix_scheduled",
     "paged_ademamix",
+    "paged_ademamix_scheduled",
 ]


@@ -309,18 +318,15 @@ def test_global_config(dim1, dim2, gtype):
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
     torch.set_printoptions(precision=6)

-    if gtype == torch.bfloat16 and optim_name not in [
-        "adam8bit_blockwise",
-        "lion8bit_blockwise",
-        "ademamix8bit_blockwise",
-    ]:
+    if gtype == torch.bfloat16 and "blockwise" not in optim_name:
         pytest.skip()
+
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
     p2 = p1.clone()
     p1 = p1.float()
-    blocksize = 2048
+    blocksize = 256

     torch_optimizer = str2optimizers[optim_name][0]([p1])
     bnb_optimizer = str2optimizers[optim_name][1]([p2])
@@ -347,8 +353,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
         torch_optimizer.step()

         # since Lion can have pretty noisy updates where things lie at the boundary
-        # and AdEMAMix can diverge as well, allow up to 0.05% errors.
-        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-4))
+        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)

         dequant_states = []
         for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -392,11 +397,11 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
         err = torch.abs(p1 - p2)
         relerr = err / (torch.abs(p1) + 1e-9)
         if g.dtype == torch.bfloat16:
-            assert err.mean() < 0.00015
-            assert relerr.mean() < 0.0020  # 0.0016
+            assert err.mean() <= 0.00017
+            assert relerr.mean() <= 0.0016
         else:
-            assert err.mean() < 0.00016  # 0.00012
-            assert relerr.mean() < 0.0016  # 0.0012
+            assert err.mean() < 0.00006
+            assert relerr.mean() < 0.0006

         errors.append(err.mean().item())
         relerrors.append(relerr.mean().item())
@@ -454,9 +459,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):

         num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
         assert num_not_close.sum().item() < 20
-        # since Lion can have pretty noisy updates where things lie at the boundary
-        # and AdEMAMix can also be noisy, allow up to 0.05%.
-        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-04))
+
+        # Lion can have pretty noisy updates where things lie at the boundary
+        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)

     # the parameters diverge quickly. Here we keep them close
     # together so we can test against the Adam error
@@ -560,15 +565,19 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
 optimizer_names_benchmark = [
     "adam8bit_blockwise",
     "paged_adam8bit_blockwise",
-    "paged_adamw8bit_blockwise",
+    "ademamix8bit_blockwise",
+    "paged_ademamix8bit_blockwise",
+    "ademamix8bit_blockwise_scheduled",
+    "paged_ademamix8bit_blockwise_scheduled",
+    "lion8bit_blockwise",
     "paged_lion8bit_blockwise",
     "paged_ademamix8bit_blockwise",
 ]


 @pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
+@pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype)
 @pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
 @pytest.mark.benchmark
 def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
@@ -580,8 +589,9 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):

     g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
     p1.grad = g
-    for i in range(k):
-        if i == k // 5:
+    total_steps = 500
+    for i in range(total_steps):
+        if i == total_steps // 5:
             # 100 iterations for burn-in
             torch.cuda.synchronize()
             t0 = time.time()
@@ -591,8 +601,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     torch.cuda.synchronize()
     s = time.time() - t0
     print("")
-    params = (k - k // 5) * dim1 * dim2
-    print(optim_name, gtype, s / params)
+    params = (total_steps - total_steps // 5) * dim1 * dim2
+    print(optim_name, gtype, s, params, s / params)
     # assert s < 3.9
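The benchmark now prints elapsed seconds, total parameter updates, and seconds per single parameter update over the timed portion of the run; the bookkeeping works out as follows (arithmetic only, mirroring the test):

    total_steps, dim1, dim2 = 500, 4096, 4096
    timed_steps = total_steps - total_steps // 5  # 400 steps after the 100-step burn-in
    params = timed_steps * dim1 * dim2            # ~6.7e9 parameter updates
    # printed value: s / params = seconds per parameter update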