Commit 6758d45

[feat] Add UNITER model wrapper
Add UNITER model to mmf registry with support for pretraining through yaml head configs.

ghstack-source-id: 028fd3e
Pull Request resolved: #1127
1 parent 133adf5 commit 6758d45
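
For context, a minimal sketch of how the newly registered model can be instantiated once this commit lands. The config values below mirror configs/models/uniter/defaults.yaml and are illustrative, not part of the commit:

# Hedged sketch: assumes mmf is importable; setup_imports registers model classes.
from mmf.common.registry import registry
from mmf.utils.env import setup_imports
from omegaconf import OmegaConf

setup_imports()  # registers mmf models, including "uniter"
UNITERClass = registry.get_model_class("uniter")
config = OmegaConf.create(
    {
        "do_pretraining": False,
        "tasks": "vqa2",
        "heads": {"vqa2": {"type": "mlp", "num_labels": 3129}},
        "losses": {"vqa2": "logit_bce"},
    }
)
model = UNITERClass(config)  # merges config over the Config dataclass defaults
model.build()  # picks UNITERForClassification since do_pretraining is False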

File tree

4 files changed: +289 −2 lines changed

configs/models/uniter/defaults.yaml
mmf/models/__init__.py
mmf/models/uniter.py
tests/models/test_uniter.py
configs/models/uniter/defaults.yaml

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+model_config:
+  uniter:
+    heads:
+      vqa2:
+        type: mlp
+        freeze: false
+        lr_multiplier: 1.0
+        in_dim: 768
+        hidden_size: 1536
+        num_labels: 3129
+        pooler_name: bert_pooler
+    text_embeddings:
+      type: bert_embeddings
+    image_embeddings:
+      type: uniter_image_embeddings
+      params:
+        name: 'uniter_image_embeddings'
+    encoder:
+      type: transformer
+      params:
+        bert_model_name: bert-base-uncased
+        hidden_size: 768
+        num_hidden_layers: 12
+        num_attention_heads: 12
+        output_attentions: false
+        output_hidden_states: false
+    tasks:
+    - vqa2

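These yaml values are merged over the model's dataclass defaults at construction time. A minimal sketch of that merge, mirroring the OmegaConf.create({**asdict(self.Config()), **config}) call in mmf/models/uniter.py below; the field subset here is chosen for illustration:

from dataclasses import asdict, dataclass, field
from typing import Any

from omegaconf import MISSING, OmegaConf


@dataclass
class Config:
    bert_model_name: str = "bert-base-uncased"
    hidden_size: int = 768
    heads: Any = MISSING  # required: must come from yaml
    tasks: Any = MISSING  # required: must come from yaml


yaml_config = OmegaConf.create(
    {"heads": {"vqa2": {"type": "mlp", "num_labels": 3129}}, "tasks": ["vqa2"]}
)
# Unpacking the dataclass defaults first lets the yaml values override them.
merged = OmegaConf.create({**asdict(Config()), **yaml_config})
print(merged.hidden_size)            # 768, from the dataclass default
print(merged.heads.vqa2.num_labels)  # 3129, from the yaml
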
mmf/models/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -17,7 +17,7 @@
 from .visual_bert import VisualBERT
 from .vilbert import ViLBERT
 from .albef.vit import AlbefVitEncoder
-
+from .uniter import UNITER
 
 __all__ = [
     "TopDownBottomUp",
@@ -43,4 +43,5 @@
     "UnimodalModal",
    "UnimodalText",
     "AlbefVitEncoder",
+    "UNITER",
 ]
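
A tiny sanity check of the new export (assumes an installed mmf; not part of the commit):

# The class is now importable from the package root, and config_path points at
# the defaults yaml added above.
from mmf.models import UNITER

assert UNITER.config_path() == "configs/models/uniter/defaults.yaml"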

mmf/models/uniter.py

Lines changed: 150 additions & 1 deletion

@@ -7,14 +7,16 @@
 import copy
 import logging
 import random
+from dataclasses import asdict, dataclass, field
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
 from mmf.common.registry import registry
+from mmf.models import BaseModel
 from mmf.modules.losses import MMFLoss
 from mmf.utils.general import retry_n
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import MISSING, DictConfig, OmegaConf
 from torch import Tensor, nn
 from transformers.modeling_bert import BertConfig, BertEmbeddings, BertModel
 
@@ -624,3 +626,150 @@ def _remove_mismatched_captions(self, processed_sample_list: Dict[str, Tensor]):
                 x = x[pos_pairs_mask]
             else:
                 x = x[pos_pairs_mask, ::]
+
+
+@registry.register_model("uniter")
+class UNITER(BaseModel):
+    """ Modification for Joint Vision-Language Encoding
+    """
+
+    @dataclass
+    class Config:
+        random_init: bool = False
+        bert_model_name: str = "bert-base-uncased"
+        img_dim: int = 2048
+        hidden_size: int = 768
+        hidden_dropout_prob: float = 0
+        text_embeddings: Any = field(default_factory=lambda: {})
+        encoder: Any = field(default_factory=lambda: {})
+        heads: Any = MISSING
+        losses: Any = field(default_factory=lambda: {})
+        tasks: Any = MISSING
+        do_pretraining: bool = False
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = OmegaConf.create({**asdict(self.Config()), **config})
+        self.do_pretraining = self.config.do_pretraining
+
+    @classmethod
+    def config_path(cls):
+        return "configs/models/uniter/defaults.yaml"
+
+    def build(self):
+        params = dict(
+            **self.config,
+            head_configs=self.config.heads,
+            loss_configs=self.config.losses,
+        )
+        if self.do_pretraining:
+            self.uniter = UNITERForPretraining(**params)
+        else:
+            self.uniter = UNITERForClassification(**params)
+
+        self.tasks = self.config.tasks
+        if isinstance(self.tasks, str):
+            self.tasks = self.tasks.split(",")
+
+    def init_losses(self):
+        """
+        Defer loss management to submodels,
+        do nothing when called by build_model.
+        """
+
+    def add_pos_feat(self, sample_list: Dict[str, Tensor]):
+        assert "image_info_0" in sample_list
+        assert "bbox" in sample_list["image_info_0"]
+
+        # (x1, y1, x2, y2), dim = (bs, num_feats, 4)
+        bboxs = torch.tensor(sample_list["image_info_0"]["bbox"])[:, :, :4]
+        norm_xy = torch.clone(bboxs)
+        # if bboxs are not normalized, just do it here
+        if norm_xy[0, 0, 0] < 1:
+            img_h = (
+                torch.tensor(sample_list["image_info_0"]["image_height"])
+                .unsqueeze(1)
+                .unsqueeze(1)
+            )  # (bs,)
+            img_w = (
+                torch.tensor(sample_list["image_info_0"]["image_width"])
+                .unsqueeze(1)
+                .unsqueeze(1)
+            )  # (bs,)
+            max_image_size = torch.cat([img_w, img_h, img_w, img_h], dim=-1)
+            max_image_size = max_image_size.to(norm_xy.device)
+            norm_xy /= max_image_size
+
+        bbox_w = (norm_xy[:, :, 2] - norm_xy[:, :, 0]).unsqueeze(-1)
+        bbox_h = (norm_xy[:, :, 3] - norm_xy[:, :, 1]).unsqueeze(-1)
+        area = bbox_w * bbox_h
+        # normalized (x1, y1, x2, y2, w, h, area)
+        pos_feat = torch.cat([norm_xy, bbox_w, bbox_h, area], dim=-1).to(
+            sample_list["image_feature_0"]
+        )
+        sample_list["img_pos_feat"] = pos_feat
+
+    def add_custom_params(self, sample_list: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        image_feat = sample_list["image_feat"] = sample_list["image_feature_0"]
+
+        image_info = getattr(sample_list, "image_info_0", {})
+        image_dim = getattr(image_info, "max_features", None)
+        sample_list["image_dim"] = image_dim
+
+        image_mask = torch.arange(image_feat.size(-2), device=image_feat.device).expand(
+            image_feat.size()[:-1]
+        )
+        if len(image_dim.size()) < len(image_mask.size()):
+            image_dim = image_dim.unsqueeze(-1)
+            assert len(image_dim.size()) == len(image_mask.size())
+        image_mask = image_mask < image_dim
+        sample_list["image_mask"] = image_mask.long()
+
+        attention_mask = torch.cat(
+            (sample_list["input_mask"], sample_list["image_mask"]), dim=-1
+        )
+        sample_list["attention_mask"] = attention_mask
+        task_index = torch.randint(len(self.tasks), (1,)).item()
+        sample_list["task"] = self.tasks[task_index]
+        sample_list["position_ids"] = torch.arange(
+            0,
+            sample_list["input_ids"].size(1),
+            dtype=torch.long,
+            device=image_feat.device,
+        ).unsqueeze(0)
+
+        self.add_pos_feat(sample_list)
+        return sample_list
+
+    def forward(self, sample_list: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        sample_list = self.add_custom_params(sample_list)
+        return self.uniter(sample_list)
+
+    def get_attention_mask(
+        self,
+        sample_list: Dict[str, Tensor],
+        text_embedding: Tensor,
+        image_embedding: Tensor,
+    ) -> Tensor:
+        image_mask = getattr(sample_list, "image_mask", None)
+
+        if image_mask is not None and sample_list.input_mask is not None:
+            attention_mask = torch.cat((sample_list.input_mask, image_mask), dim=-1)
+        elif image_mask is not None:
+            text_mask = torch.ones(
+                text_embedding.size()[:-1],
+                dtype=text_embedding.dtype,
+                device=text_embedding.device,
+            )
+            attention_mask = torch.cat((image_mask, text_mask), dim=-1)
+        elif sample_list.input_mask is not None:
+            image_mask = torch.ones(
+                image_embedding.size()[:-1],
+                dtype=image_embedding.dtype,
+                device=image_embedding.device,
+            )
+            attention_mask = torch.cat((image_mask, sample_list.input_mask), dim=-1)
+        else:
+            attention_mask = None
+
+        return attention_mask
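
Two parts of the wrapper above benefit from concrete numbers. First, add_pos_feat builds a 7-dimensional positional feature (x1, y1, x2, y2, w, h, area) per region. A standalone sketch of that math, using a made-up 200x100 image with a single box:

import torch

bbox = torch.tensor([[[10.0, 20.0, 60.0, 100.0]]])  # (bs=1, num_feats=1, 4)
img_w, img_h = 200.0, 100.0

# normalize (x1, y1, x2, y2) by image width/height
norm_xy = bbox / torch.tensor([img_w, img_h, img_w, img_h])

bbox_w = (norm_xy[:, :, 2] - norm_xy[:, :, 0]).unsqueeze(-1)  # 0.25
bbox_h = (norm_xy[:, :, 3] - norm_xy[:, :, 1]).unsqueeze(-1)  # 0.80
area = bbox_w * bbox_h                                        # 0.20

# 7-dim positional feature, all values in [0, 1]
pos_feat = torch.cat([norm_xy, bbox_w, bbox_h, area], dim=-1)
print(pos_feat)  # tensor([[[0.05, 0.20, 0.30, 1.00, 0.25, 0.80, 0.20]]])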

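Second, add_custom_params derives the image mask by comparing an arange over feature positions against each sample's valid feature count; a toy illustration with made-up sizes:

import torch

image_feat = torch.rand(2, 4, 8)      # (bs=2, num_feats=4, feat_dim=8)
image_dim = torch.tensor([[3], [4]])  # valid features per sample

# arange over feature positions, broadcast to (bs, num_feats), then threshold
image_mask = torch.arange(image_feat.size(-2)).expand(image_feat.size()[:-1])
image_mask = (image_mask < image_dim).long()
print(image_mask)
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 1]])
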
tests/models/test_uniter.py

Lines changed: 109 additions & 0 deletions

@@ -1,6 +1,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
+import gc
 import unittest
 
+import tests.test_utils as test_utils
 import torch
 from mmf.common.sample import SampleList
 from mmf.models.uniter import (
@@ -9,6 +11,9 @@
     UNITERImageEmbeddings,
     UNITERModelBase,
 )
+from mmf.utils.build import build_model
+from mmf.utils.configuration import Configuration
+from mmf.utils.env import setup_imports, teardown_imports
 from mmf.utils.general import get_current_device
 from omegaconf import OmegaConf
 
@@ -166,3 +171,107 @@ def test_uniter_for_pretraining(self):
 
         self.assertTrue("losses" in model_output)
         self.assertTrue(loss_name in model_output["losses"])
+
+
+class TestUniterModel(unittest.TestCase):
+    def setUp(self):
+        test_utils.setup_proxy()
+        setup_imports()
+        model_name = "uniter"
+        args = test_utils.dummy_args(model=model_name, dataset="vqa2")
+        configuration = Configuration(args)
+        config = configuration.get_config()
+        model_config = config.model_config[model_name]
+        model_config.model = model_name
+        model_config.losses = {"vqa2": "logit_bce"}
+        model_config.do_pretraining = False
+        model_config.tasks = "vqa2"
+        classification_config_dict = {
+            "do_pretraining": False,
+            "tasks": "vqa2",
+            "heads": {"vqa2": {"type": "mlp", "num_labels": 3129}},
+            "losses": {"vqa2": "logit_bce"},
+        }
+        classification_config = OmegaConf.create(
+            {**model_config, **classification_config_dict}
+        )
+
+        pretraining_config_dict = {
+            "do_pretraining": True,
+            "tasks": "wra",
+            "heads": {"wra": {"type": "wra"}},
+        }
+        pretraining_config = OmegaConf.create(
+            {**model_config, **pretraining_config_dict}
+        )
+
+        self.model_for_classification = build_model(classification_config)
+        self.model_for_pretraining = build_model(pretraining_config)
+
+    def tearDown(self):
+        teardown_imports()
+        del self.model_for_classification
+        del self.model_for_pretraining
+        gc.collect()
+
+    def _get_sample_list(self):
+        bs = 8
+        num_feats = 100
+        max_sentence_len = 25
+        img_dim = 2048
+        vqa_cls_dim = 3129
+        input_ids = torch.ones((bs, max_sentence_len), dtype=torch.long)
+        input_mask = torch.ones((bs, max_sentence_len), dtype=torch.long)
+        img_feat = torch.rand((bs, num_feats, img_dim))
+
+        max_features = torch.ones((bs, num_feats)) * num_feats
+        bbox = torch.randint(50, 200, (bs, num_feats, 4)).float()
+        image_height = torch.randint(100, 300, (bs,))
+        image_width = torch.randint(100, 300, (bs,))
+        image_info = {
+            "max_features": max_features,
+            "bbox": bbox,
+            "image_height": image_height,
+            "image_width": image_width,
+        }
+        targets = torch.rand((bs, vqa_cls_dim))
+        is_correct = torch.ones((bs,), dtype=torch.long)
+
+        sample_list = SampleList()
+        sample_list.add_field("input_ids", input_ids)
+        sample_list.add_field("image_feature_0", img_feat)
+        sample_list.add_field("input_mask", input_mask)
+        sample_list.add_field("image_info_0", image_info)
+        sample_list.add_field("targets", targets)
+        sample_list.add_field("is_correct", is_correct)
+        sample_list = sample_list.to(get_current_device())
+        return sample_list
+
+    def test_uniter_for_classification(self):
+        self.model_for_classification.eval()
+        self.model_for_classification = self.model_for_classification.to(
+            get_current_device()
+        )
+        sample_list = self._get_sample_list()
+
+        sample_list.dataset_name = "vqa2"
+        sample_list.dataset_type = "test"
+        with torch.no_grad():
+            model_output = self.model_for_classification(sample_list)
+
+        self.assertTrue("losses" in model_output)
+        self.assertTrue("test/vqa2/logit_bce" in model_output["losses"])
+
+    def test_uniter_for_pretraining(self):
+        self.model_for_pretraining.eval()
+        self.model_for_pretraining = self.model_for_pretraining.to(get_current_device())
+        sample_list = self._get_sample_list()
+        sample_list["tasks"] = "wra"
+
+        sample_list.dataset_name = "vqa2"
+        sample_list.dataset_type = "test"
+        with torch.no_grad():
+            model_output = self.model_for_pretraining(sample_list)
+
+        self.assertTrue("losses" in model_output)
+        self.assertTrue("wra_loss" in model_output["losses"])