Skip to content
49 changes: 28 additions & 21 deletions mmf/models/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,25 +103,13 @@ def build(self):

def get_optimizer_parameters(self, config):
lr = config.optimizer.params.lr

backbone_param_set = set()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's name this trunk_params_set.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. I will change my code according to your suggestion.

param_list = []
parameters = []
head_configs = self.config.get("heads", [])

for name, module in self.named_children():
# Heads can have different learning rates. This is handled here
if name == "heads":
# Parameters in the head which have a separate learning
# rate, are added as a separate param group
for head_config, head in zip(head_configs, self.heads):
parameters, param_list = self.set_lr_for_parameters(
config=head_config,
module_name="{} head".format(head_config.get("type", "MLP")),
base_lr=lr,
module=head,
parameters=parameters,
param_list=param_list,
)
elif name == "encoders":

if name == "encoders":
for key in module:
for modality in self.config.modalities:
if key == modality.key:
Expand All @@ -134,29 +122,48 @@ def get_optimizer_parameters(self, config):
parameters=parameters,
param_list=param_list,
)
else:
elif name != "heads":
Comment thread
butterluo marked this conversation as resolved.
Outdated
# For other modules in trunk, add to same param group
param_list += list(module.named_parameters())

backbone_param_set.update(list(module.parameters()))
head_configs = self.config.get("heads", [])
# Heads can have different learning rates. This is handled here
if len(head_configs) > 0:
# Parameters in the head which have a separate learning
# rate, are added as a separate param group
for head_config, head in zip(head_configs, self.heads):
parameters, param_list = self.set_lr_for_parameters(
config=head_config,
module_name="{} head".format(head_config.get("type", "MLP")),
base_lr=lr,
module=head,
parameters=parameters,
param_list=param_list,
backbone_param_set = backbone_param_set
)
parameters += get_bert_configured_parameters(param_list)

return parameters

def set_lr_for_parameters(
    self,
    config,
    module_name,
    base_lr,
    module,
    parameters,
    param_list,
    backbone_param_set=None,
):
    """Assign ``module``'s parameters to an optimizer param group.

    If ``config`` specifies an ``lr_multiplier`` different from 1.0, the
    module's parameters are appended to ``parameters`` as their own param
    group with learning rate ``base_lr * lr_multiplier``; otherwise they are
    appended to ``param_list`` so the caller can group them with the trunk's
    learning rate.

    Args:
        config: head/encoder config; only ``config.get("lr_multiplier")``
            is read here.
        module_name: human-readable name, used only in the log message.
        base_lr: base learning rate of the trunk.
        module: module whose ``named_parameters()`` are being registered.
        parameters: accumulated list of optimizer param groups (mutated via
            ``+=`` and returned).
        param_list: accumulated ``(name, parameter)`` pairs that share the
            trunk learning rate (mutated via ``+=`` and returned).
        backbone_param_set: optional set of parameter objects already
            registered with the trunk; any parameter found in it is skipped
            here so that no parameter lands in two param groups. ``None``
            (the default) is treated as the empty set, i.e. no filtering.

    Returns:
        Tuple of the (possibly extended) ``parameters`` and ``param_list``.
    """
    # Normalize the sentinel so a single filtering pass handles both the
    # "no backbone set" and "filter against backbone set" cases.
    if backbone_param_set is None:
        backbone_param_set = set()
    lr_multiplier = config.get("lr_multiplier", 1.0)
    # Drop parameters the trunk already owns. NOTE(review): membership is
    # effectively by object identity (the original code used the same
    # `not in` test on the parameter objects) — confirm that is intended.
    module_param = [
        (param_name, param)
        for param_name, param in module.named_parameters()
        if param not in backbone_param_set
    ]
    if lr_multiplier != 1.0:
        logger.info(
            f"Setting learning rate of {module_name} to be {base_lr} * {lr_multiplier}."
        )  # noqa
        # Separate param group with a scaled learning rate.
        parameters += get_bert_configured_parameters(
            module_param, base_lr * lr_multiplier
        )
    else:
        # Parameters for the modules with same learning rate as
        # trunk, add to same param group
        param_list += module_param
    return parameters, param_list

def build_encoders(self):
Expand Down