@@ -27,128 +27,67 @@ def prod(iterable):
2727if lib and lib .compiled_with_cuda :
2828 """C FUNCTIONS FOR OPTIMIZERS"""
2929 str2optimizer32bit = {
30- "adagrad" : (
31- lib .cadagrad32bit_grad_fp32 ,
32- lib .cadagrad32bit_grad_fp16 ,
33- ),
3430 "adam" : (
3531 lib .cadam32bit_grad_fp32 ,
3632 lib .cadam32bit_grad_fp16 ,
3733 lib .cadam32bit_grad_bf16 ,
3834 ),
39- "pagedadam" : (
40- lib .cpagedadam32bit_grad_fp32 ,
41- lib .cpagedadam32bit_grad_fp16 ,
42- lib .cpagedadam32bit_grad_bf16 ,
43- ),
44- "adamw" : (
45- lib .cadam32bit_grad_fp32 ,
46- lib .cadam32bit_grad_fp16 ,
47- lib .cadam32bit_grad_bf16 ,
48- ),
49- "pagedadamw" : (
50- lib .cpagedadam32bit_grad_fp32 ,
51- lib .cpagedadam32bit_grad_fp16 ,
52- lib .cpagedadam32bit_grad_bf16 ,
53- ),
54- "lamb" : (
55- lib .cadam32bit_grad_fp32 ,
56- lib .cadam32bit_grad_fp16 ,
35+ "momentum" : (
36+ lib .cmomentum32bit_grad_32 ,
37+ lib .cmomentum32bit_grad_16 ,
5738 ),
58- "lars " : (
59- lib .clars32bit_grad_fp32 ,
60- lib .clars32bit_grad_fp16 ,
39+ "rmsprop " : (
40+ lib .crmsprop32bit_grad_32 ,
41+ lib .crmsprop32bit_grad_16 ,
6142 ),
6243 "lion" : (
6344 lib .clion32bit_grad_fp32 ,
6445 lib .clion32bit_grad_fp16 ,
6546 lib .clion32bit_grad_bf16 ,
6647 ),
67- "momentum " : (
68- lib .cmomentum32bit_grad_fp32 ,
69- lib .cmomentum32bit_grad_fp16 ,
48+ "adagrad " : (
49+ lib .cadagrad32bit_grad_32 ,
50+ lib .cadagrad32bit_grad_16 ,
7051 ),
71- "rmsprop " : (
72- lib .crmsprop32bit_grad_fp32 ,
73- lib .crmsprop32bit_grad_fp16 ,
52+ "lamb " : (
53+ lib .cadam32bit_grad_fp32 ,
54+ lib .cadam32bit_grad_fp16 ,
7455 ),
7556 }
7657
7758 str2optimizer8bit = {
78- "adagrad" : (
79- lib .cadagrad8bit_grad_fp32 ,
80- lib .cadagrad8bit_grad_fp16 ,
81- ),
8259 "adam" : (
83- lib .cadam_static_8bit_grad_fp32 ,
84- lib .cadam_static_8bit_grad_fp16 ,
60+ lib .cadam_static_8bit_grad_32 ,
61+ lib .cadam_static_8bit_grad_16 ,
8562 ),
86- "pagedadam" : (
87- lib .cpagedadam8bit_grad_fp32 ,
88- lib .cpagedadam8bit_grad_fp16 ,
89- lib .cpagedadam8bit_grad_bf16 ,
63+ "momentum" : (
64+ lib .cmomentum_static_8bit_grad_32 ,
65+ lib .cmomentum_static_8bit_grad_16 ,
9066 ),
91- "adamw " : (
92- lib .cadam_static_8bit_grad_fp32 ,
93- lib .cadam_static_8bit_grad_fp16 ,
67+ "rmsprop " : (
68+ lib .crmsprop_static_8bit_grad_32 ,
69+ lib .crmsprop_static_8bit_grad_16 ,
9470 ),
95- "pagedadamw" : (
96- lib .cpagedadam8bit_grad_fp32 ,
97- lib .cpagedadam8bit_grad_fp16 ,
98- lib .cpagedadam8bit_grad_bf16 ,
71+ "lion" : (
72+ lib .clion_static_8bit_grad_32 ,
73+ lib .clion_static_8bit_grad_16 ,
9974 ),
10075 "lamb" : (
101- lib .cadam_static_8bit_grad_fp32 ,
102- lib .cadam_static_8bit_grad_fp16 ,
76+ lib .cadam_static_8bit_grad_32 ,
77+ lib .cadam_static_8bit_grad_16 ,
10378 ),
10479 "lars" : (
105- lib .clars8bit_grad_fp32 ,
106- lib .clars8bit_grad_fp16 ,
107- ),
108- "lion" : (
109- lib .clion_static_8bit_grad_fp32 ,
110- lib .clion_static_8bit_grad_fp16 ,
111- ),
112- "momentum" : (
113- lib .cmomentum_static_8bit_grad_fp32 ,
114- lib .cmomentum_static_8bit_grad_fp16 ,
115- ),
116- "rmsprop" : (
117- lib .crmsprop_static_8bit_grad_fp32 ,
118- lib .crmsprop_static_8bit_grad_fp16 ,
80+ lib .cmomentum_static_8bit_grad_32 ,
81+ lib .cmomentum_static_8bit_grad_16 ,
11982 ),
12083 }
12184
12285 str2optimizer8bit_blockwise = {
123- "adagrad" : (
124- lib .cadagrad_8bit_blockwise_grad_fp32 ,
125- lib .cadagrad_8bit_blockwise_grad_fp16 ,
126- ),
12786 "adam" : (
12887 lib .cadam_8bit_blockwise_grad_fp32 ,
12988 lib .cadam_8bit_blockwise_grad_fp16 ,
13089 lib .cadam_8bit_blockwise_grad_bf16 ,
13190 ),
132- "pagedadam" : (
133- lib .cpagedadam8bit_blockwise_fp32 ,
134- lib .cpagedadam8bit_blockwise_fp16 ,
135- lib .cpagedadam8bit_blockwise_bf16 ,
136- ),
137- "adamw" : (
138- lib .cadam_8bit_blockwise_grad_fp32 ,
139- lib .cadam_8bit_blockwise_grad_fp16 ,
140- lib .cadam_8bit_blockwise_grad_bf16 ,
141- ),
142- "pagedadamw" : (
143- lib .cpagedadam8bit_blockwise_fp32 ,
144- lib .cpagedadam8bit_blockwise_fp16 ,
145- lib .cpagedadam8bit_blockwise_bf16 ,
146- ),
147- "lion" : (
148- lib .clion_8bit_blockwise_grad_fp32 ,
149- lib .clion_8bit_blockwise_grad_fp16 ,
150- lib .clion_8bit_blockwise_grad_bf16 ,
151- ),
15291 "momentum" : (
15392 lib .cmomentum_8bit_blockwise_grad_fp32 ,
15493 lib .cmomentum_8bit_blockwise_grad_fp16 ,
@@ -157,6 +96,15 @@ def prod(iterable):
15796 lib .crmsprop_8bit_blockwise_grad_fp32 ,
15897 lib .crmsprop_8bit_blockwise_grad_fp16 ,
15998 ),
99+ "lion" : (
100+ lib .clion_8bit_blockwise_grad_fp32 ,
101+ lib .clion_8bit_blockwise_grad_fp16 ,
102+ lib .clion_8bit_blockwise_grad_bf16 ,
103+ ),
104+ "adagrad" : (
105+ lib .cadagrad_8bit_blockwise_grad_fp32 ,
106+ lib .cadagrad_8bit_blockwise_grad_fp16 ,
107+ ),
160108 }
161109
162110
0 commit comments