Skip to content

Commit b22ae26

Browse files
fix for faulty #1222 ("Add "lamb" to str2optimizer32bit") (#1240)
* Revert "Add `"lamb"` to `str2optimizer32bit`" * Update bitsandbytes/functional.py
1 parent 1f2ca43 commit b22ae26

1 file changed

Lines changed: 36 additions & 88 deletions

File tree

bitsandbytes/functional.py

Lines changed: 36 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -27,128 +27,67 @@ def prod(iterable):
2727
if lib and lib.compiled_with_cuda:
2828
"""C FUNCTIONS FOR OPTIMIZERS"""
2929
str2optimizer32bit = {
30-
"adagrad": (
31-
lib.cadagrad32bit_grad_fp32,
32-
lib.cadagrad32bit_grad_fp16,
33-
),
3430
"adam": (
3531
lib.cadam32bit_grad_fp32,
3632
lib.cadam32bit_grad_fp16,
3733
lib.cadam32bit_grad_bf16,
3834
),
39-
"pagedadam": (
40-
lib.cpagedadam32bit_grad_fp32,
41-
lib.cpagedadam32bit_grad_fp16,
42-
lib.cpagedadam32bit_grad_bf16,
43-
),
44-
"adamw": (
45-
lib.cadam32bit_grad_fp32,
46-
lib.cadam32bit_grad_fp16,
47-
lib.cadam32bit_grad_bf16,
48-
),
49-
"pagedadamw": (
50-
lib.cpagedadam32bit_grad_fp32,
51-
lib.cpagedadam32bit_grad_fp16,
52-
lib.cpagedadam32bit_grad_bf16,
53-
),
54-
"lamb": (
55-
lib.cadam32bit_grad_fp32,
56-
lib.cadam32bit_grad_fp16,
35+
"momentum": (
36+
lib.cmomentum32bit_grad_32,
37+
lib.cmomentum32bit_grad_16,
5738
),
58-
"lars": (
59-
lib.clars32bit_grad_fp32,
60-
lib.clars32bit_grad_fp16,
39+
"rmsprop": (
40+
lib.crmsprop32bit_grad_32,
41+
lib.crmsprop32bit_grad_16,
6142
),
6243
"lion": (
6344
lib.clion32bit_grad_fp32,
6445
lib.clion32bit_grad_fp16,
6546
lib.clion32bit_grad_bf16,
6647
),
67-
"momentum": (
68-
lib.cmomentum32bit_grad_fp32,
69-
lib.cmomentum32bit_grad_fp16,
48+
"adagrad": (
49+
lib.cadagrad32bit_grad_32,
50+
lib.cadagrad32bit_grad_16,
7051
),
71-
"rmsprop": (
72-
lib.crmsprop32bit_grad_fp32,
73-
lib.crmsprop32bit_grad_fp16,
52+
"lamb": (
53+
lib.cadam32bit_grad_fp32,
54+
lib.cadam32bit_grad_fp16,
7455
),
7556
}
7657

7758
str2optimizer8bit = {
78-
"adagrad": (
79-
lib.cadagrad8bit_grad_fp32,
80-
lib.cadagrad8bit_grad_fp16,
81-
),
8259
"adam": (
83-
lib.cadam_static_8bit_grad_fp32,
84-
lib.cadam_static_8bit_grad_fp16,
60+
lib.cadam_static_8bit_grad_32,
61+
lib.cadam_static_8bit_grad_16,
8562
),
86-
"pagedadam": (
87-
lib.cpagedadam8bit_grad_fp32,
88-
lib.cpagedadam8bit_grad_fp16,
89-
lib.cpagedadam8bit_grad_bf16,
63+
"momentum": (
64+
lib.cmomentum_static_8bit_grad_32,
65+
lib.cmomentum_static_8bit_grad_16,
9066
),
91-
"adamw": (
92-
lib.cadam_static_8bit_grad_fp32,
93-
lib.cadam_static_8bit_grad_fp16,
67+
"rmsprop": (
68+
lib.crmsprop_static_8bit_grad_32,
69+
lib.crmsprop_static_8bit_grad_16,
9470
),
95-
"pagedadamw": (
96-
lib.cpagedadam8bit_grad_fp32,
97-
lib.cpagedadam8bit_grad_fp16,
98-
lib.cpagedadam8bit_grad_bf16,
71+
"lion": (
72+
lib.clion_static_8bit_grad_32,
73+
lib.clion_static_8bit_grad_16,
9974
),
10075
"lamb": (
101-
lib.cadam_static_8bit_grad_fp32,
102-
lib.cadam_static_8bit_grad_fp16,
76+
lib.cadam_static_8bit_grad_32,
77+
lib.cadam_static_8bit_grad_16,
10378
),
10479
"lars": (
105-
lib.clars8bit_grad_fp32,
106-
lib.clars8bit_grad_fp16,
107-
),
108-
"lion": (
109-
lib.clion_static_8bit_grad_fp32,
110-
lib.clion_static_8bit_grad_fp16,
111-
),
112-
"momentum": (
113-
lib.cmomentum_static_8bit_grad_fp32,
114-
lib.cmomentum_static_8bit_grad_fp16,
115-
),
116-
"rmsprop": (
117-
lib.crmsprop_static_8bit_grad_fp32,
118-
lib.crmsprop_static_8bit_grad_fp16,
80+
lib.cmomentum_static_8bit_grad_32,
81+
lib.cmomentum_static_8bit_grad_16,
11982
),
12083
}
12184

12285
str2optimizer8bit_blockwise = {
123-
"adagrad": (
124-
lib.cadagrad_8bit_blockwise_grad_fp32,
125-
lib.cadagrad_8bit_blockwise_grad_fp16,
126-
),
12786
"adam": (
12887
lib.cadam_8bit_blockwise_grad_fp32,
12988
lib.cadam_8bit_blockwise_grad_fp16,
13089
lib.cadam_8bit_blockwise_grad_bf16,
13190
),
132-
"pagedadam": (
133-
lib.cpagedadam8bit_blockwise_fp32,
134-
lib.cpagedadam8bit_blockwise_fp16,
135-
lib.cpagedadam8bit_blockwise_bf16,
136-
),
137-
"adamw": (
138-
lib.cadam_8bit_blockwise_grad_fp32,
139-
lib.cadam_8bit_blockwise_grad_fp16,
140-
lib.cadam_8bit_blockwise_grad_bf16,
141-
),
142-
"pagedadamw": (
143-
lib.cpagedadam8bit_blockwise_fp32,
144-
lib.cpagedadam8bit_blockwise_fp16,
145-
lib.cpagedadam8bit_blockwise_bf16,
146-
),
147-
"lion": (
148-
lib.clion_8bit_blockwise_grad_fp32,
149-
lib.clion_8bit_blockwise_grad_fp16,
150-
lib.clion_8bit_blockwise_grad_bf16,
151-
),
15291
"momentum": (
15392
lib.cmomentum_8bit_blockwise_grad_fp32,
15493
lib.cmomentum_8bit_blockwise_grad_fp16,
@@ -157,6 +96,15 @@ def prod(iterable):
15796
lib.crmsprop_8bit_blockwise_grad_fp32,
15897
lib.crmsprop_8bit_blockwise_grad_fp16,
15998
),
99+
"lion": (
100+
lib.clion_8bit_blockwise_grad_fp32,
101+
lib.clion_8bit_blockwise_grad_fp16,
102+
lib.clion_8bit_blockwise_grad_bf16,
103+
),
104+
"adagrad": (
105+
lib.cadagrad_8bit_blockwise_grad_fp32,
106+
lib.cadagrad_8bit_blockwise_grad_fp16,
107+
),
160108
}
161109

162110

0 commit comments

Comments
 (0)