@@ -53,7 +53,15 @@ def __init__(self):
 
     # TODO unify this with the other scripts
     def generate_step_logs(
-        self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
+        self,
+        args: argparse.Namespace,
+        current_loss,
+        avr_loss,
+        lr_scheduler,
+        lr_descriptions,
+        keys_scaled=None,
+        mean_norm=None,
+        maximum_norm=None,
     ):
         logs = {"loss/current": current_loss, "loss/average": avr_loss}
 
@@ -63,68 +71,25 @@ def generate_step_logs(
             logs["max_norm/max_key_norm"] = maximum_norm
 
         lrs = lr_scheduler.get_last_lr()
-
-        if len(lrs) > 4:
-            idx = 0
-            if not args.network_train_unet_only:
-                logs["lr/textencoder"] = float(lrs[0])
-                idx = 1
-
-            for i in range(idx, len(lrs)):
-                lora_plus = ""
-                group_id = i
-
-                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-                    lora_plus = '_lora+' if i % 2 == 1 else ''
-                    group_id = int((i / 2) + (i % 2 + 0.5))
-
-                logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
-                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                    logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
-                        lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
-                    )
-
-        else:
-            if args.network_train_text_encoder_only:
-                if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
-                    logs["lr/textencoder"] = float(lrs[0])
-                    logs["lr/textencoder_lora+"] = float(lrs[1])
-                else:
-                    logs["lr/textencoder"] = float(lrs[0])
-
-            elif args.network_train_unet_only:
-                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-                    logs["lr/unet"] = float(lrs[0])
-                    logs["lr/unet_lora+"] = float(lrs[1])
-                else:
-                    logs["lr/unet"] = float(lrs[0])
+        for i, lr in enumerate(lrs):
+            if lr_descriptions is not None:
+                lr_desc = lr_descriptions[i]
             else:
-                if len(lrs) == 2:
-                    if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
-                        logs["lr/textencoder"] = float(lrs[0])
-                        logs["lr/textencoder_lora+"] = float(lrs[1])
-                    elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
-                        logs["lr/unet"] = float(lrs[0])
-                        logs["lr/unet_lora+"] = float(lrs[1])
-                    elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
-                        logs["lr/all"] = float(lrs[0])
-                        logs["lr/all_lora+"] = float(lrs[1])
-                    else:
-                        logs["lr/textencoder"] = float(lrs[0])
-                        logs["lr/unet"] = float(lrs[-1])
-                elif len(lrs) == 4:
-                    logs["lr/textencoder"] = float(lrs[0])
-                    logs["lr/textencoder_lora+"] = float(lrs[1])
-                    logs["lr/unet"] = float(lrs[2])
-                    logs["lr/unet_lora+"] = float(lrs[3])
+                idx = i - (0 if args.network_train_unet_only else 1)  # when the text encoder is trained, its LR comes first
+                if idx == -1:
+                    lr_desc = "textencoder"
                 else:
-                    logs["lr/all"] = float(lrs[0])
+                    if len(lrs) > 2:
+                        lr_desc = f"group{idx}"
+                    else:
+                        lr_desc = "unet"
+
+            logs[f"lr/{lr_desc}"] = lr
 
-        if (
-            args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-        ):  # tracking d*lr value of unet.
-            logs["lr/d*lr"] = (
-                lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+            if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
+                # tracking d*lr value
+                logs[f"lr/d*lr/{lr_desc}"] = (
+                    lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                 )
 
         return logs
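
For reviewers: the rewritten loop above replaces the old per-configuration branching with a single pass over the scheduler's learning rates, preferring the names supplied in lr_descriptions and falling back to positional naming. Below is a minimal standalone sketch of that naming rule; describe_lrs is a hypothetical helper, not part of this patch, and it assumes the same lr_descriptions and network_train_unet_only semantics as the code above.

# Hypothetical helper mirroring the naming rule in generate_step_logs (illustration only).
def describe_lrs(lrs, lr_descriptions=None, unet_only=False):
    names = []
    for i, _lr in enumerate(lrs):
        if lr_descriptions is not None:
            names.append(lr_descriptions[i])  # explicit names from the network win
        else:
            idx = i - (0 if unet_only else 1)  # first LR belongs to the text encoder unless U-Net only
            if idx == -1:
                names.append("textencoder")
            elif len(lrs) > 2:
                names.append(f"group{idx}")
            else:
                names.append("unet")
    return names

print(describe_lrs([1e-4, 1e-4]))                             # ['textencoder', 'unet']
print(describe_lrs([1e-4, 4e-4], ["unet_lr", "unet_lora+"]))  # ['unet_lr', 'unet_lora+']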
@@ -358,6 +323,7 @@ def train(self, args):
         network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
 
         if args.network_weights is not None:
+            # FIXME consider alpha of weights
             info = network.load_weights(args.network_weights)
             accelerator.print(f"load network weights from {args.network_weights}: {info}")
 
@@ -373,20 +339,23 @@ def train(self, args):
 
         # ensure backward compatibility
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
+            results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            if type(results) is tuple:
+                trainable_params = results[0]
+                lr_descriptions = results[1]
+            else:
+                trainable_params = results
+                lr_descriptions = None
         except TypeError:
-            accelerator.print(
-                "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
-            )
+            # accelerator.print(
+            #     "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
+            # )
             trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
+            lr_descriptions = None
 
         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
-        if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-            assert (
-                (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
-            ), "LoRA+ and Prodigy/DAdaptation is not supported"
-
         # prepare the dataloader
         # note: persistent_workers cannot be used when the number of DataLoader processes is 0
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
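
The try/except above keeps backward compatibility in both directions: a current network module may return a (trainable_params, lr_descriptions) tuple from prepare_optimizer_params, an older one returns only the parameter groups, and the oldest signature without learning_rate raises TypeError and is retried. A minimal sketch of the two return shapes this caller accepts; both network classes are hypothetical, for illustration only.

# Hypothetical network modules showing the two return shapes the caller unpacks.
class LegacyNetwork:
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        # param groups only -> caller sets lr_descriptions = None
        return [{"params": [], "lr": unet_lr or default_lr}]

class DescribedNetwork:
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        # param groups plus a parallel list of labels used for logging
        params = [
            {"params": [], "lr": text_encoder_lr or default_lr},
            {"params": [], "lr": unet_lr or default_lr},
        ]
        return params, ["textencoder", "unet"]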
@@ -992,7 +961,9 @@ def remove_model(old_ckpt_name):
                 progress_bar.set_postfix(**{**max_mean_logs, **logs})
 
                 if args.logging_dir is not None:
-                    logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
+                    logs = self.generate_step_logs(
+                        args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
+                    )
                     accelerator.log(logs, step=global_step)
 
                 if global_step >= args.max_train_steps:
@@ -1143,6 +1114,9 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
+    # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
+    # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
     return parser
 
 