Skip to content

Commit d578f79

Browse files
[feat] Add VinVL model wrapper (#1151)
Summary: Pull Request resolved: #1151 Add VinVL BaseModel for training and testing from MMF. This model defers to either the classification or pretraining model depending on its config. For an example config consult the project dir or unit tests. Test Plan: ### Unit Tests Tested BaseModel instantiation from config, and forward pass for classification and pretraining. Reviewed By: ebsmothers Differential Revision: D32574738 Pulled By: Ryan-Qiyu-Jiang fbshipit-source-id: 50f8396821effd778c6d5184cd940864fc1eb3b1
1 parent a51b977 commit d578f79

File tree

4 files changed

+243
-2
lines changed

4 files changed

+243
-2
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Default configuration for the registered "vinvl" model
# (loaded via VinVL.config_path()).
# NOTE(review): indentation below was reconstructed from a flattened diff —
# verify the nesting of the keys after `heads` against the original file.
model_config:
  vinvl:
    # Transformer heads keyed by name; VinVL.build() looks up
    # "mlm"/"contrast" (pretraining) or "mlp" (classification).
    heads:
      test:
        type: mlp
        freeze: false
        lr_multiplier: 1.0
        in_dim: 768
        hidden_size: 1536
        # presumably the VQA2 answer-vocabulary size — TODO confirm
        num_labels: 3129
        pooler_name: bert_pooler
    bert_model_name: bert-base-uncased
    loss_type: sfmx
    # presumably 2048 ROI features + 6 box-geometry dims — TODO confirm
    img_feature_dim: 2054
    img_feature_type: 'frcnn'
    use_img_layernorm: 1
    img_layer_norm_eps: 1e-12
    # Maximum number of image regions appended to the text sequence.
    max_img_seq_len: 70

mmf/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .uniter import UNITER
1919
from .vilbert import ViLBERT
2020
from .vilt import ViLT
21+
from .vinvl import VinVL
2122
from .visual_bert import VisualBERT
2223

2324
__all__ = [
@@ -46,4 +47,5 @@
4647
"AlbefVitEncoder",
4748
"ViLT",
4849
"UNITER",
50+
"VinVL",
4951
]

mmf/models/vinvl.py

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@
66

77
import logging
88
from collections import namedtuple
9-
from dataclasses import asdict
10-
from typing import Dict, Optional, Tuple
9+
from dataclasses import asdict, dataclass
10+
from typing import Any, Dict, Optional, Tuple
1111

1212
import torch
13+
from mmf.common.registry import registry
1314
from mmf.common.sample import SampleList
15+
from mmf.models.base_model import BaseModel
1416
from mmf.models.transformers.heads.contrastive import ThreeWayContrastive
1517
from mmf.models.transformers.heads.mlm import MLM
1618
from mmf.models.transformers.heads.mlp import MLP
1719
from mmf.utils.general import retry_n
20+
from omegaconf import MISSING, OmegaConf
1821
from torch import Tensor, nn
1922
from transformers.modeling_bert import (
2023
BertConfig,
@@ -378,3 +381,119 @@ def forward(
378381
)
379382
losses = {**mlm_result, **contrastive_loss_result}
380383
return {"losses": losses}
384+
385+
386+
@registry.register_model("vinvl")
class VinVL(BaseModel):
    """VinVL base model called by MMF.

    VinVL paper, 3-way contrastive loss:
    https://arxiv.org/pdf/2101.00529.pdf

    Implementation based on https://github.com/microsoft/Oscar

    Expects VinVL features extracted by
    https://github.com/microsoft/scene_graph_benchmark
    using Visual Genome object detection labels.

    The label map used for training is available at
    https://github.com/microsoft/scene_graph_benchmark/blob/main/README.md
    """

    @dataclass
    class Config:
        # Defaults merged underneath the user-provided config in __init__.
        random_init: bool = False
        bert_model_name: str = "bert-base-uncased"
        hidden_size: int = 768
        heads: Any = MISSING
        do_pretraining: bool = False
        img_feature_dim: int = 2054
        img_feature_type: str = "frcnn"
        use_img_layernorm: bool = True
        img_layer_norm_eps: float = 1e-12
        max_img_seq_len: int = 70

    def __init__(self, config):
        super().__init__(config)
        # User-supplied keys override the dataclass defaults.
        defaults = asdict(self.Config())
        self.config = OmegaConf.create({**defaults, **config})
        self.do_pretraining = self.config.do_pretraining

    @classmethod
    def config_path(cls):
        return "configs/models/vinvl/defaults.yaml"

    def build(self):
        """Instantiate the pretraining or classification submodel,
        selected by ``do_pretraining`` in the config."""
        head_configs = self.config.heads
        if self.do_pretraining:
            self.vinvl = VinVLForPretraining(
                mlm_config=head_configs.get("mlm"),
                contrast_config=head_configs.get("contrast"),
                **self.config,
            )
        else:
            # Classification finetuning path (e.g. VQA-style heads).
            self.vinvl = VinVLForClassification(
                mlp_config=head_configs.get("mlp"),
                loss_config=self.config.get("ce_loss"),
                **self.config,
            )

    def init_losses(self):
        """
        Defer loss management to submodels,
        do nothing when called by build_model.
        """

    def forward(self, sample_list: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Route the batch to the active submodel, building the combined
        text+image attention mask first."""
        image_feat = sample_list["image_feature_0"]
        image_info = sample_list["image_info_0"]
        attention_mask = self._get_attention_mask(
            image_feat, image_info, sample_list["input_mask"]
        )

        if not self.do_pretraining:
            return self.vinvl(
                sample_list["input_ids"],
                sample_list["segment_ids"],
                attention_mask,
                image_feat,
                labels=sample_list.get("labels"),
            )

        # Pretraining also needs a mask for the corrupted caption stream
        # used by the 3-way contrastive objective.
        corrupt_attention_mask = self._get_attention_mask(
            image_feat, image_info, sample_list["input_mask_corrupt"]
        )
        return self.vinvl(
            sample_list["input_ids_masked"],
            sample_list["input_ids_corrupt"],
            sample_list["lm_label_ids"],
            sample_list["contrastive_labels"],
            sample_list["segment_ids"],
            attention_mask,
            sample_list["segment_ids_corrupt"],
            corrupt_attention_mask,
            image_feat,
        )

    def _get_attention_mask(
        self, image_feat: Tensor, image_info: Dict[str, Tensor], input_mask: Tensor
    ) -> Tensor:
        """Concatenate the text mask with a per-region image mask.

        ``image_info["max_features"]`` — when present — holds, per batch
        element, the number of valid feature rows; regions past that count
        are masked out. Without it, every region is assumed valid.
        """
        num_valid = image_info.get("max_features")
        if num_valid is None:
            image_mask = torch.ones(
                (image_feat.size(0), image_feat.size(1)), device=image_feat.device
            ).long()
        else:
            # positions: (..., num_regions) index grid matching image_feat's
            # leading dims; a region is valid while its index < num_valid.
            positions = torch.arange(
                image_feat.size(-2), device=image_feat.device
            ).expand(image_feat.size()[:-1])
            if len(num_valid.size()) < len(positions.size()):
                num_valid = num_valid.unsqueeze(-1)
            assert len(num_valid.size()) == len(positions.size())
            image_mask = (positions < num_valid).long()

        return torch.cat((input_mask, image_mask), dim=-1)

tests/models/test_vinvl.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,19 @@
22

33
import unittest
44

5+
import tests.test_utils as test_utils
56
import torch
7+
from mmf.common.sample import SampleList
68
from mmf.models.vinvl import (
79
VinVLBase,
810
VinVLForClassification,
911
VinVLForPretraining,
1012
)
13+
from mmf.utils.build import build_model
14+
from mmf.utils.configuration import Configuration
15+
from mmf.utils.env import setup_imports, teardown_imports
1116
from mmf.utils.general import get_current_device
17+
from omegaconf import OmegaConf
1218
from transformers.modeling_bert import BertConfig
1319

1420

@@ -93,3 +99,99 @@ def test_pretraining_forward(self):
9399
self.assertTrue("losses" in model_output)
94100
self.assertTrue("masked_lm_loss" in model_output["losses"])
95101
self.assertTrue("three_way_contrastive_loss" in model_output["losses"])
102+
103+
104+
class TestVinVLModel(unittest.TestCase):
    """End-to-end tests for the registered ``vinvl`` BaseModel: build it
    through MMF's ``build_model`` and run a forward pass for both the
    classification and the pretraining configurations."""

    def setUp(self):
        test_utils.setup_proxy()
        setup_imports()
        model_name = "vinvl"
        args = test_utils.dummy_args(model=model_name, dataset="test")
        configuration = Configuration(args)
        base_config = configuration.get_config().model_config[model_name]
        base_config.model = model_name
        base_config.do_pretraining = False

        # Classification variant: MLP head plus cross-entropy loss.
        self.classification_config = OmegaConf.create(
            {
                **base_config,
                "do_pretraining": False,
                "heads": {"mlp": {"num_labels": 3129}},
                "ce_loss": {"ignore_index": -1},
            }
        )

        # Pretraining variant: MLM head (contrastive head uses defaults).
        self.pretraining_config = OmegaConf.create(
            {
                **base_config,
                "do_pretraining": True,
                "heads": {"mlm": {"hidden_size": 768}},
            }
        )

        self.sample_list = self._get_sample_list()

    def tearDown(self):
        teardown_imports()

    def _get_sample_list(self):
        """Build a mocked SampleList with every field both submodels read."""
        batch_size = 8
        num_feats = 70

        class MockObj:
            pass

        mock_input = MockObj()
        mock_vinvl_input_tensors(mock_input, bs=batch_size, num_feats=num_feats)

        text_mask = torch.ones_like(mock_input.input_ids)
        image_info = {
            # every region marked valid (count == num_feats)
            "max_features": torch.ones((batch_size, num_feats)) * num_feats,
            "bbox": torch.randint(50, 200, (batch_size, num_feats, 4)).float(),
            "image_height": torch.randint(100, 300, (batch_size,)),
            "image_width": torch.randint(100, 300, (batch_size,)),
        }

        fields = {
            "input_ids": mock_input.input_ids,
            "input_ids_corrupt": mock_input.input_ids,
            "input_ids_masked": mock_input.input_ids,
            "image_feature_0": mock_input.img_feats,
            "image_info_0": image_info,
            "input_mask": text_mask,
            "input_mask_corrupt": text_mask,
            "segment_ids": mock_input.token_type_ids,
            "segment_ids_corrupt": mock_input.token_type_ids,
            "labels": mock_input.labels,
            "contrastive_labels": mock_input.contrastive_labels,
            "lm_label_ids": mock_input.lm_label_ids,
        }
        sample_list = SampleList()
        for field_name, value in fields.items():
            sample_list.add_field(field_name, value)

        sample_list = sample_list.to(get_current_device())
        sample_list.dataset_name = "test"
        sample_list.dataset_type = "test"
        return sample_list

    def test_vinvl_for_classification(self):
        model = build_model(self.classification_config)
        model.eval()
        model = model.to(get_current_device())
        with torch.no_grad():
            model_output = model(self.sample_list)

        self.assertTrue("losses" in model_output)
        self.assertTrue("ce" in model_output["losses"])

    def test_vinvl_for_pretraining(self):
        model = build_model(self.pretraining_config)
        model.eval()
        model = model.to(get_current_device())

        with torch.no_grad():
            model_output = model(self.sample_list)

        self.assertTrue("losses" in model_output)
        self.assertTrue("masked_lm_loss" in model_output["losses"])
        self.assertTrue("three_way_contrastive_loss" in model_output["losses"])

0 commit comments

Comments
 (0)