Commit 6f3f40f

Sethu Sankaran authored and facebook-github-bot committed

ReFNet changes: Use encoders for refiner pooler inputs and fix an edge case in MS loss calculation when no negative pairs are found. (#1160)

Summary:
Pull Request resolved: #1160

The loss value dict needs to be initialized and set to default values when a batch has no negative pairs (for the MS loss calculation). Also, the refiner class should take the encoded layers rather than the sequence output.

Reviewed By: ebsmothers

Differential Revision: D32696676

fbshipit-source-id: b52d8532445141499152353c7893fb83ef6142c4

1 parent cd6f58f commit 6f3f40f

File tree

4 files changed: +26 −15 lines changed

mmf/models/transformers/heads/refiner.py (4 additions, 4 deletions)

```diff
@@ -120,10 +120,10 @@ def forward(
             end_token[modality] = start_token[modality] + sz[1] - 1
             prev_end_token = end_token[modality] + 1
 
-        attention_mask = torch.cat(masks, dim=1)
+        pad_mask = torch.cat(masks, dim=1)
         processed_sample_list["refiner_outputs"] = {}
         processed_sample_list["refiner_outputs"]["fused_embedding"] = self.pooler(
-            sequence_output, attention_mask
+            encoded_layers, pad_mask
         )
         processed_sample_list["refiner_targets"] = {}
         for modality in self.modalities:
@@ -132,7 +132,7 @@ def forward(
             tk_end = end_token[modality]
             for enc_layers in encoded_layers:
                 modality_encodings.append(enc_layers[:, tk_start : tk_end + 1, :])
-            modality_mask_encodings = attention_mask[:, tk_start : tk_end + 1]
+            modality_mask_encodings = pad_mask[:, tk_start : tk_end + 1]
             processed_sample_list["refiner_targets"][modality] = self.pooler(
                 modality_encodings, modality_mask_encodings
             )
@@ -169,7 +169,7 @@ def forward(
                 ]
                 refiner_modal_outputs = {}
                 refiner_modal_outputs["scores"] = refiner_reconstruct[modality]
-                loss = self.refinerloss(modality_targets, refiner_modal_outputs)
+                loss = self.refiner_loss(modality_targets, refiner_modal_outputs)
 
             else:
                 loss = self.weights[modality] * self.refiner_loss(
```
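The pooler now receives the full list of encoded layers plus the concatenated padding mask instead of only the final sequence output. A minimal sketch of what an "average_k_from_last"-style pooler over those inputs could look like (the class name and signature are illustrative, not the MMF implementation; the config change below selects such a pooler by name):

```python
from typing import List

import torch
from torch import Tensor


class AverageKFromLastPooler(torch.nn.Module):
    """Masked mean over the average of the last k encoder layers (sketch)."""

    def __init__(self, k: int = 2):
        super().__init__()
        self.k = k  # number of final encoder layers to average

    def forward(self, encoded_layers: List[Tensor], pad_mask: Tensor) -> Tensor:
        # encoded_layers: per-layer tensors of shape (batch, seq_len, hidden)
        # pad_mask: (batch, seq_len) with 1 for real tokens, 0 for padding
        hidden = torch.stack(encoded_layers[-self.k :], dim=0).mean(dim=0)
        mask = pad_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        count = mask.sum(dim=1).clamp(min=1e-6)  # guard all-padding rows
        return summed / count  # (batch, hidden) fused embedding
```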

mmf/models/transformers/heads/refnet_classifier.py (2 additions, 0 deletions)

```diff
@@ -58,6 +58,8 @@ def forward(
         targets_subset = {}
         targets_subset["targets"] = processed_sample_list["target_key"]["targets"]
         targets_subset["targets"] = targets_subset["targets"][:score_max]
+        if "losses" not in output_dict.keys():
+            output_dict["losses"] = {}
         output_dict["losses"][self.loss_name] = self.loss_fn(
             targets_subset, scores_subset
         )
```
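The new guard matters when this head is the first writer into `output_dict`: without it, indexing `output_dict["losses"]` raises a `KeyError`. A tiny self-contained illustration, including an equivalent `dict.setdefault` one-liner (a sketch, not the committed code):

```python
output_dict = {}  # no "losses" key yet, as when no other head has run

# Committed pattern: create the dict before indexing into it.
if "losses" not in output_dict:
    output_dict["losses"] = {}
output_dict["losses"]["classification_loss"] = 0.42

# Equivalent one-liner: setdefault returns the existing "losses" dict
# or inserts and returns an empty one.
output_dict.setdefault("losses", {})["classification_loss"] = 0.42
print(output_dict)  # {'losses': {'classification_loss': 0.42}}
```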

mmf/modules/losses.py (4 additions, 2 deletions)

```diff
@@ -933,8 +933,10 @@ def forward(self, sample_list, model_output):
             pos_loss = calc_ms_loss(pos_pair, self.base, self.beta, -1)
             neg_loss = calc_ms_loss(neg_pairs, self.base, self.alpha, 1)
             loss.append(pos_loss + neg_loss)
-
-        loss = sum(loss) / n
+        if n > 0:
+            loss = sum(loss) / n
+        else:
+            loss = inputs.new_zeros(1, requires_grad=True)
         return loss
 
 
```
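Why a zero tensor rather than a plain Python `0`: the loss must stay a tensor on the embeddings' device and dtype so the training loop can call `.backward()` unconditionally. A minimal sketch of the edge case being guarded (names are illustrative; `inputs` stands for the embedding tensor the loss operates on):

```python
import torch


def batch_ms_loss(per_anchor_losses, inputs):
    # Mirrors the committed guard: average per-anchor losses when any exist,
    # otherwise return a differentiable zero on the embeddings' device/dtype.
    n = len(per_anchor_losses)
    if n > 0:
        return sum(per_anchor_losses) / n
    # Old code computed sum(loss) / n here, which raises ZeroDivisionError
    # when a batch contains no negative pairs (n == 0).
    return inputs.new_zeros(1, requires_grad=True)


embeddings = torch.randn(4, 8)
# A batch where no negative pairs were found yields an empty loss list:
print(batch_ms_loss([], embeddings))  # tensor([0.], requires_grad=True)
```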

projects/mmbt/configs/mmimdb/paper_ablations_reducedlabel.yaml (16 additions, 9 deletions)

```diff
@@ -36,23 +36,30 @@ model_config:
       hidden_size: 768
       vocab_size: 30522
       loss_type: "cosine"
+      refiner_target_pooler: "average_k_from_last"
+      refiner_target_layer_depth: 2
       modalities:
       - "text"
       - "image"
       weights:
       - 0.0
       - 0.0
-    mlp_config:
-      type: mlp
-      freeze: false
-      num_labels: 24
-      lr_multiplier: 1.0
-      hidden_size: 768
-      vocab_size: 30522
+    mlp_loss_config:
+      config:
+        type: mlp
+        num_labels: 24
+        hidden_size: 768
+        hidden_dropout_prob: 0.1
+        layer_norm_eps: 0.000001
+        hidden_act: gelu
+        pooler_name: bert_pooler
+      num_layers: 1
+      loss_name: classification_loss
       loss: logit_bce
-    max_sample_size: 32
-    ms_loss_weight: 0.05
+    max_sample_size: 33
+    ms_loss_weight: 0.0
     use_msloss: true
+    num_labels: 24
     self_weight_decay: 0.997
 dataset_config:
   mmimdb:
```
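A quick sanity check that the restructured block parses with the nesting the head code expects (a sketch using omegaconf, which MMF configs are built on; only the keys added in the diff above are shown):

```python
from omegaconf import OmegaConf

# Parse just the restructured fragment to confirm the new nesting;
# values mirror the diff above.
cfg = OmegaConf.create(
    """
    refiner_target_pooler: average_k_from_last
    refiner_target_layer_depth: 2
    mlp_loss_config:
      config:
        type: mlp
        num_labels: 24
        pooler_name: bert_pooler
      num_layers: 1
      loss_name: classification_loss
    """
)
assert cfg.refiner_target_layer_depth == 2
assert cfg.mlp_loss_config.config.pooler_name == "bert_pooler"
```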
