@@ -66,34 +66,69 @@ def generate_step_logs(
 
         lrs = lr_scheduler.get_last_lr()
 
-        if args.network_train_text_encoder_only or len(lrs) <= 2:  # not block lr (or single block)
-            if args.network_train_unet_only:
-                logs["lr/unet"] = float(lrs[0])
-            elif args.network_train_text_encoder_only:
-                logs["lr/textencoder"] = float(lrs[0])
-            else:
-                logs["lr/textencoder"] = float(lrs[0])
-                logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder
-
-            if (
-                args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-            ):  # tracking d*lr value of unet.
-                logs["lr/d*lr"] = (
-                    lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
-                )
-        else:
+        if len(lrs) > 4:
             idx = 0
             if not args.network_train_unet_only:
                 logs["lr/textencoder"] = float(lrs[0])
                 idx = 1
 
             for i in range(idx, len(lrs)):
-                logs[f"lr/group{i}"] = float(lrs[i])
+                lora_plus = ""
+                group_id = i
+
+                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+                    lora_plus = '_lora+' if i % 2 == 1 else ''
+                    group_id = int((i / 2) + (i % 2 + 0.5))
+
+                logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
                 if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                    logs[f"lr/d*lr/group{i}"] = (
+                    logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
                         lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                     )
 
+        else:
+            if args.network_train_text_encoder_only:
+                if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
+                    logs["lr/textencoder"] = float(lrs[0])
+                    logs["lr/textencoder_lora+"] = float(lrs[1])
+                else:
+                    logs["lr/textencoder"] = float(lrs[0])
+
+            elif args.network_train_unet_only:
+                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+                    logs["lr/unet"] = float(lrs[0])
+                    logs["lr/unet_lora+"] = float(lrs[1])
+                else:
+                    logs["lr/unet"] = float(lrs[0])
+            else:
+                if len(lrs) == 2:
+                    if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
+                        logs["lr/textencoder"] = float(lrs[0])
+                        logs["lr/textencoder_lora+"] = float(lrs[1])
+                    elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
+                        logs["lr/unet"] = float(lrs[0])
+                        logs["lr/unet_lora+"] = float(lrs[1])
+                    elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
+                        logs["lr/all"] = float(lrs[0])
+                        logs["lr/all_lora+"] = float(lrs[1])
+                    else:
+                        logs["lr/textencoder"] = float(lrs[0])
+                        logs["lr/unet"] = float(lrs[-1])
+                elif len(lrs) == 4:
+                    logs["lr/textencoder"] = float(lrs[0])
+                    logs["lr/textencoder_lora+"] = float(lrs[1])
+                    logs["lr/unet"] = float(lrs[2])
+                    logs["lr/unet_lora+"] = float(lrs[3])
+                else:
+                    logs["lr/all"] = float(lrs[0])
+
+            if (
+                args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
+            ):  # tracking d*lr value of unet.
+                logs["lr/d*lr"] = (
+                    lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+                )
+
         return logs
 
     def assert_extra_args(self, args, train_dataset_group):
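Note on the block-LR branch above: with LoRA+ enabled, `get_last_lr()` is expected to return the per-group learning rates with each block's base group immediately followed by its LoRA+ group, which is why `i % 2 == 1` marks the `_lora+` entries. Below is a minimal sketch of how such interleaved groups could be built, assuming sd-scripts' `lora_down`/`lora_up` module naming; `split_param_groups` is a hypothetical helper, not part of this diff:

    # Hypothetical helper: pair a base group (lora_down / LoRA-A weights) with a
    # boosted group (lora_up / LoRA-B weights) whose lr is scaled by the LoRA+
    # ratio, so get_last_lr() yields [base0, lora+0, base1, lora+1, ...].
    def split_param_groups(named_params, lr, loraplus_ratio):
        base, plus = [], []
        for name, param in named_params:
            (plus if "lora_up" in name else base).append(param)
        return [
            {"params": base, "lr": lr},
            {"params": plus, "lr": lr * loraplus_ratio},
        ]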
@@ -339,7 +374,7 @@ def train(self, args):
 
         # ensure backward compatibility
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio)
+            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
         except TypeError:
             accelerator.print(
                 "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
@@ -348,6 +383,11 @@ def train(self, args):
 
         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
+        if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+            assert (
+                optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name
+            ), "LoRA+ and Prodigy/DAdaptation are not supported"
+
         # prepare the dataloader
         # note: persistent_workers cannot be used when the DataLoader process count is 0
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
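The new assertion guards against combining LoRA+ with D-Adaptation or Prodigy, presumably because those optimizers derive their effective step size `d * lr` from their own estimate of `d`, which works against LoRA+'s premise of a fixed, hand-set learning-rate ratio between the two LoRA matrices. The ratio itself works as in this sketch (values are illustrative; the LoRA+ paper suggests ratios on the order of 16):

    learning_rate = 1e-4        # base lr, applied to lora_down / LoRA-A
    loraplus_lr_ratio = 16      # fixed multiplier, e.g. --loraplus_lr_ratio=16
    lora_up_lr = learning_rate * loraplus_lr_ratio  # 1.6e-3, applied to lora_up / LoRA-B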