Skip to content

Commit a51b977

Browse files
[feat] Add VinVL classification and pretraining models (#1150)
Summary: Pull Request resolved: #1150 Add VinVL classification and pretraining models that use the VinVL BertImgModel trunk. These are nn.Module objects, usable outside of MMF. For example text preprocessing for pretraining, take a look at VinVLTextTokenizer in a later diff. Models forward returns dict with scores and losses. For example usage consult the unit tests or VinVL basemodel. Test Plan: ### Unit Tests Tested forward passes for classification and pretraining models. Pretraining model forward was tested in end-to-end on winoground dataset. Reviewed By: apsdehal Differential Revision: D32574735 Pulled By: Ryan-Qiyu-Jiang fbshipit-source-id: fc1a58db421a33d941b1ddbb5f5a3f35e308e741
1 parent 8e67391 commit a51b977

File tree

2 files changed

+338
-1
lines changed

2 files changed

+338
-1
lines changed

mmf/models/vinvl.py

Lines changed: 280 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,15 @@
66

77
import logging
88
from collections import namedtuple
9-
from typing import Optional, Tuple
9+
from dataclasses import asdict
10+
from typing import Dict, Optional, Tuple
1011

1112
import torch
13+
from mmf.common.sample import SampleList
14+
from mmf.models.transformers.heads.contrastive import ThreeWayContrastive
15+
from mmf.models.transformers.heads.mlm import MLM
16+
from mmf.models.transformers.heads.mlp import MLP
17+
from mmf.utils.general import retry_n
1218
from torch import Tensor, nn
1319
from transformers.modeling_bert import (
1420
BertConfig,
@@ -19,6 +25,8 @@
1925

2026
logger = logging.getLogger(__name__)
2127

28+
NUM_RETRIES = 6
29+
2230

2331
class VinVLBase(BertPreTrainedModel):
2432
"""VinVL Bert Encoder for image features
@@ -99,3 +107,274 @@ def forward(
99107
)
100108
layers = namedtuple("TransformerOutput", ["last_hidden_state", "hidden_layers"])
101109
return layers(encoder_outputs[0], encoder_outputs[1])
110+
111+
112+
def build_vinvl_base(
    bert_model_name: str = "bert-base-uncased",
    img_feature_dim: int = 2054,
    use_img_layernorm: bool = True,
    img_layer_norm_eps: float = 1e-12,
    random_init: bool = True,
) -> VinVLBase:
    """Build a VinVLBase (BertImgModel-style) trunk from a hf BertConfig.

    Args:
        bert_model_name (str): hf model name used to fetch the base
            BertConfig (and pretrained weights when random_init is False).
        img_feature_dim (int): size of the VinVL image feature inputs.
        use_img_layernorm (bool): flag to layernorm the image encoding.
        img_layer_norm_eps (float): image layernorm epsilon.
        random_init (bool): if True the trunk is randomly initialized;
            otherwise pretrained bert weights are loaded.

    Returns:
        VinVLBase: the constructed encoder trunk.
    """
    # retry_n guards against transient failures when fetching from the hf hub
    bert_config = retry_n(
        NUM_RETRIES,
        BertConfig.from_pretrained,
        bert_model_name,
    )
    # augment hf BertConfig for vinvl BertImgModel config
    bert_config.img_feature_dim = img_feature_dim
    bert_config.use_img_layernorm = use_img_layernorm
    bert_config.img_layer_norm_eps = img_layer_norm_eps

    if random_init:
        bert = VinVLBase(bert_config)
    else:
        bert = retry_n(
            NUM_RETRIES,
            VinVLBase.from_pretrained,
            bert_model_name,
            config=bert_config,
        )
    return bert
139+
140+
141+
class VinVLForClassification(nn.Module):
    """VINVL wrapper for classification"""

    def __init__(
        self,
        mlp_config: Optional[Dict] = None,
        loss_config: Optional[Dict] = None,
        random_init: bool = False,
        bert_model_name: str = "bert-base-uncased",
        img_feature_dim: int = 2054,
        use_img_layernorm: bool = True,
        img_layer_norm_eps: float = 1e-12,
        *args,
        **kwargs,
    ):
        """VinVL model constructor for classification.
        MLP head is configurable through Dict type.
        Consult the MLP head class for the config options.

        Args:
            mlp_config (Optional[Dict], optional):
                Classifier MLP head config.
                Defaults to {"num_layers": 0}.
            loss_config (Optional[Dict], optional):
                nn.CrossEntropyLoss params dict.
                Defaults to {}.
            random_init (bool, optional):
                Flag to load VinVL bert weights from random_init.
                Defaults to False.
            bert_model_name (str, optional):
                Name for base bert model.
                Used for VinVL base configs and weights.
                Defaults to "bert-base-uncased".
            img_feature_dim (int, optional):
                The size of the VinVL image feature inputs.
                Defaults to 2054.
            use_img_layernorm (bool, optional):
                Flag to use layernorm on image encoding.
                Defaults to True.
            img_layer_norm_eps (float, optional):
                Image layernorm epsilon. Defaults to 1e-12.
        """
        super().__init__()
        # None defaults avoid mutable default arguments
        if mlp_config is None:
            mlp_config = {"num_layers": 0}
        if loss_config is None:
            loss_config = {}

        self.bert = build_vinvl_base(
            bert_model_name=bert_model_name,
            img_feature_dim=img_feature_dim,
            use_img_layernorm=use_img_layernorm,
            img_layer_norm_eps=img_layer_norm_eps,
            random_init=random_init,
        )
        self.classifier = MLP(config=mlp_config)
        self.ce_loss = nn.CrossEntropyLoss(**loss_config)

    def forward(
        self,
        input_ids: Tensor,
        token_type_ids: Tensor,
        attention_mask: Tensor,
        img_feats: Tensor,
        position_ids: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Run the trunk and classifier head.

        Returns a dict with "scores" (classifier logits); when `labels`
        is given, also "losses" containing the cross-entropy loss under
        key "ce".
        """
        sequence_output = self.bert(
            input_ids,
            img_feats=img_feats,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # MLP head returns a dict; "scores" holds the logits
        logits = self.classifier(sequence_output)["scores"]
        result = {"scores": logits}

        if labels is not None:
            # flatten to (N, num_classes) vs (N,) for CrossEntropyLoss
            ce_loss = self.ce_loss(logits.view(-1, logits.size(1)), labels.view(-1))
            result["losses"] = {"ce": ce_loss}
        return result
222+
223+
224+
class VinVLForPretraining(nn.Module):
    """VINVL wrapper for pretraining
    MLM loss is described in https://arxiv.org/pdf/2004.06165.pdf
    Contrastive loss is an itm loss to guess,
    0 for a match,
    1 for a corrupt caption,
    2 for corrupt image labels
    VinVL trains with object detection labels concatenated with the input text.
    """

    def __init__(
        self,
        mlm_config: Optional[MLM.Config] = None,
        contrast_config: Optional[ThreeWayContrastive.Config] = None,
        random_init: bool = False,
        bert_model_name: str = "bert-base-uncased",
        img_feature_dim: int = 2054,
        use_img_layernorm: bool = True,
        img_layer_norm_eps: float = 1e-12,
        *args,
        **kwargs,
    ):
        """VinVL model constructor for pretraining.
        MLM and Contrastive Loss heads are configurable through Dict types.
        Consult MLM and MLP head classes for their config options.

        Args:
            mlm_config (Optional[MLM.Config], optional):
                Config object for MLM head.
                Defaults to MLM.Config which uses the default MLM configs.
            contrast_config (Optional[ThreeWayContrastive.Config], optional):
                Config object for the 3-way contrastive head.
                Defaults to ThreeWayContrastive.Config which uses a MLP with 3 classes
            random_init (bool, optional):
                Flag to load VinVL bert weights from random_init.
                Defaults to False.
            bert_model_name (str, optional):
                Name for base bert model.
                Used for VinVL base configs and weights.
                Defaults to "bert-base-uncased".
            img_feature_dim (int, optional):
                The size of the VinVL image feature inputs.
                Defaults to 2054.
            use_img_layernorm (bool, optional):
                Flag to use layernorm on image encoding.
                Defaults to True.
            img_layer_norm_eps (float, optional):
                Image layernorm epsilon. Defaults to 1e-12.
        """
        super().__init__()
        # dataclass configs are converted to dicts for the head constructors
        if mlm_config is None:
            mlm_config = asdict(MLM.Config())
        if contrast_config is None:
            contrast_config = asdict(ThreeWayContrastive.Config())

        self.bert = build_vinvl_base(
            bert_model_name=bert_model_name,
            img_feature_dim=img_feature_dim,
            use_img_layernorm=use_img_layernorm,
            img_layer_norm_eps=img_layer_norm_eps,
            random_init=random_init,
        )
        self.mlm_head = MLM(config=mlm_config)
        # NOTE(review): ce_loss is not used by mlm_forward/contrastive_forward
        # below — the heads compute their own losses; confirm before removing.
        self.ce_loss = nn.CrossEntropyLoss()
        self.contrast_head = ThreeWayContrastive(contrast_config)

    def mlm_forward(
        self,
        input_ids_masked: Tensor,
        lm_label_ids: Tensor,
        token_type_ids: Tensor,
        attention_mask: Tensor,
        img_feats: Tensor,
        position_ids: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Compute the MLM loss dict for masked text + image features."""
        hidden_layers = self.bert(
            input_ids_masked,
            img_feats=img_feats,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        mlm_labels = {}
        mlm_labels["text"] = lm_label_ids
        # image positions carry no MLM targets; -1 marks them as ignored
        mlm_labels["image"] = torch.full(
            img_feats.shape[:2],
            fill_value=-1,
            dtype=torch.long,
            device=lm_label_ids.device,
        )
        mlm_labels["combined_labels"] = torch.cat(
            [mlm_labels["text"], mlm_labels["image"]], dim=-1
        )

        processed_sample_list = SampleList({"mlm_labels": mlm_labels})
        return self.mlm_head(
            hidden_layers, processed_sample_list=processed_sample_list
        )["losses"]

    def contrastive_forward(
        self,
        input_ids: Tensor,
        token_type_ids: Tensor,
        attention_mask: Tensor,
        img_feats: Tensor,
        contrastive_labels: Tensor,
        position_ids: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Compute the 3-way contrastive (itm-style) loss dict."""
        last_hidden_state = self.bert(
            input_ids,
            img_feats=img_feats,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        processed_sample_list = SampleList({"contrastive_labels": contrastive_labels})
        # contrastive 3-way loss has 3 classes,
        # 0 for a match, 1, 2 for a corrupt caption/image
        # labels respectively
        return self.contrast_head(last_hidden_state, processed_sample_list)["losses"]

    def forward(
        self,
        input_ids_masked: Tensor,
        input_ids_corrupt: Tensor,
        lm_label_ids: Tensor,
        contrastive_labels: Tensor,
        token_type_ids: Tensor,
        attention_mask: Tensor,
        token_type_ids_corrupt: Tensor,
        attention_mask_corrupt: Tensor,
        img_feats: Tensor,
        position_ids: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Run both pretraining objectives and merge their loss dicts.

        MLM uses the masked inputs; the contrastive head uses the
        (possibly corrupted) caption/label inputs. Returns
        {"losses": {...}} combining both heads' losses.
        """
        mlm_result = self.mlm_forward(
            input_ids_masked,
            lm_label_ids,
            token_type_ids,
            attention_mask,
            img_feats,
            position_ids,
        )

        contrastive_loss_result = self.contrastive_forward(
            input_ids_corrupt,
            token_type_ids_corrupt,
            attention_mask_corrupt,
            img_feats,
            contrastive_labels,
            position_ids,
        )
        losses = {**mlm_result, **contrastive_loss_result}
        return {"losses": losses}

tests/models/test_vinvl.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import torch
66
from mmf.models.vinvl import (
77
VinVLBase,
8+
VinVLForClassification,
9+
VinVLForPretraining,
810
)
911
from mmf.utils.general import get_current_device
1012
from transformers.modeling_bert import BertConfig
@@ -35,3 +37,59 @@ def test_forward(self):
3537
with torch.no_grad():
3638
model_output = model(input_ids, img_feat).last_hidden_state
3739
self.assertEqual(model_output.shape, torch.Size([8, 95, 768]))
40+
41+
42+
def mock_vinvl_input_tensors(
    cls, bs=8, num_feats=70, max_sentence_len=25, img_feature_dim=2054
):
    """Attach mock VinVL model inputs to *cls* as attributes.

    Builds text tokens, image features, attention mask, and labels with
    the shapes the VinVL models expect.
    """
    text_shape = (bs, max_sentence_len)
    combined_len = max_sentence_len + num_feats

    cls.input_ids = torch.ones(text_shape, dtype=torch.long)
    cls.img_feats = torch.rand((bs, num_feats, img_feature_dim))
    cls.attention_mask = torch.ones((bs, combined_len), dtype=torch.long)
    cls.token_type_ids = torch.zeros(text_shape, dtype=torch.long)
    cls.labels = torch.ones((bs, 1), dtype=torch.long)

    # -1 marks positions the MLM loss should ignore
    cls.lm_label_ids = torch.full(text_shape, -1, dtype=torch.long)
    cls.contrastive_labels = torch.zeros((bs, 1), dtype=torch.long)
55+
56+
57+
class TestVinVLForClassificationAndPretraining(unittest.TestCase):
    """Forward-pass smoke tests for the VinVL classification and
    pretraining wrappers; constructing the models downloads/builds the
    bert trunk."""

    def setUp(self):
        # populates self.input_ids, self.img_feats, masks, and labels
        mock_vinvl_input_tensors(self)

    def test_classification_forward(self):
        """Classification forward returns scores and a "ce" loss."""
        model = VinVLForClassification().to(get_current_device())
        model.eval()

        with torch.no_grad():
            model_output = model(
                input_ids=self.input_ids,
                img_feats=self.img_feats,
                attention_mask=self.attention_mask,
                token_type_ids=self.token_type_ids,
                labels=self.labels,
            )
        self.assertTrue("losses" in model_output)
        self.assertTrue("scores" in model_output)
        self.assertTrue("ce" in model_output["losses"])

    def test_pretraining_forward(self):
        """Pretraining forward returns both MLM and contrastive losses."""
        model = VinVLForPretraining().to(get_current_device())
        model.eval()

        with torch.no_grad():
            model_output = model(
                img_feats=self.img_feats,
                attention_mask=self.attention_mask,
                token_type_ids=self.token_type_ids,
                input_ids_masked=self.input_ids,
                lm_label_ids=self.lm_label_ids,
                contrastive_labels=self.contrastive_labels,
                input_ids_corrupt=self.input_ids,
                token_type_ids_corrupt=self.token_type_ids,
                attention_mask_corrupt=self.attention_mask,
            )
        self.assertTrue("losses" in model_output)
        self.assertTrue("masked_lm_loss" in model_output["losses"])
        self.assertTrue("three_way_contrastive_loss" in model_output["losses"])

0 commit comments

Comments
 (0)