@@ -1033,22 +1033,54 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        text_encoder_loraplus_ratio=None,
+        unet_loraplus_ratio=None,
+        loraplus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []
 
-        def enumerate_params(loras: List[LoRAModule]):
-            params = []
+        def assemble_params(loras, lr, ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                # params.extend(lora.parameters())
-                params.extend(lora.get_trainable_params())
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                    continue
+
+                params.append(param_data)
+
             return params
 
         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)
 
         if self.unet_loras:
             if self.block_lr:
@@ -1062,21 +1094,20 @@ def enumerate_params(loras: List[LoRAModule]):
 
                 # set the parameters per block
                 for idx, block_loras in block_idx_to_lora.items():
-                    param_data = {"params": enumerate_params(block_loras)}
-
-                    if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
-                    elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                    if ("lr" in param_data) and (param_data["lr"] == 0):
-                        continue
-                    all_params.append(param_data)
+                    params = assemble_params(
+                        block_loras,
+                        (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                        unet_loraplus_ratio or loraplus_ratio
+                    )
+                    all_params.extend(params)
 
             else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                params = assemble_params(
+                    self.unet_loras,
+                    unet_lr if unet_lr is not None else default_lr,
+                    unet_loraplus_ratio or loraplus_ratio
+                )
+                all_params.extend(params)
 
         return all_params
 
@@ -1093,6 +1124,9 @@ def on_epoch_start(self, text_encoder, unet):
     def get_trainable_params(self):
         return self.parameters()
 
+    def get_trainable_named_params(self):
+        return self.named_parameters()
+
     def save_weights(self, file, dtype, metadata):
         if metadata is not None and len(metadata) == 0:
             metadata = None
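For context: LoRA+ (Hayou et al., 2024) trains the LoRA up-projection (matrix B, the `lora_up` weights above) with a larger learning rate than the down-projection (matrix A), which is what the new `*_loraplus_ratio` arguments control. Below is a minimal, self-contained sketch of how the parameter groups built by `assemble_params` could be consumed by a standard optimizer. `TinyLoRAModule`, the dimensions, and the ratio value are illustrative assumptions, not part of this diff.

```python
import torch

class TinyLoRAModule(torch.nn.Module):
    """Stand-in for LoRAModule: one down (A) and one up (B) projection."""
    def __init__(self, lora_name, in_dim=16, rank=4, out_dim=16):
        super().__init__()
        self.lora_name = lora_name
        self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)  # matrix A
        self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)   # matrix B

def assemble_params(loras, lr, ratio):
    # Same grouping logic as the diff: "lora_up" weights go into the "plus"
    # group when a LoRA+ ratio is given, everything else into "lora".
    param_groups = {"lora": {}, "plus": {}}
    for lora in loras:
        for name, param in lora.named_parameters():
            key = "plus" if ratio is not None and "lora_up" in name else "lora"
            param_groups[key][f"{lora.lora_name}.{name}"] = param

    params = []
    for key, group in param_groups.items():
        if not group:
            continue
        lr_value = None if lr is None else (lr * ratio if key == "plus" else lr)
        if not lr_value:  # drop groups whose lr resolves to 0 or None
            continue
        params.append({"params": list(group.values()), "lr": lr_value})
    return params

loras = [TinyLoRAModule(f"lora_unet_block{i}") for i in range(2)]
groups = assemble_params(loras, lr=1e-4, ratio=16)  # ratio=16 as suggested in the LoRA+ paper
optimizer = torch.optim.AdamW(groups)
for g in optimizer.param_groups:
    print(g["lr"], len(g["params"]))  # 1e-4 for the A matrices, 1.6e-3 for the B matrices
```

Note that the diff also skips any group whose learning rate resolves to 0 or None, so e.g. passing `text_encoder_lr=0` now excludes those weights from the optimizer entirely instead of creating a dead parameter group.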