Skip to content

Commit 7c3b527

Browse files
[feat] Add UNITER model wrapper (#1127)
Summary: Pull Request resolved: #1127 Add UNITER BaseModel for training and testing from MMF. This model defers to either the classification or pretraining model depending on its config. For an example config consult the project dir or unit tests. Test Plan: ### Unit Tests Tested forward passes for classification and pretraining from mmf basemodel build from config. ### End to End Model tested end-to-end using butd extracted features on winoground. Will work on converting caffe feature extraction to pytorch so we can add Uniter and Villa checkpoints usefully. Reviewed By: ebsmothers Differential Revision: D31768457 Pulled By: Ryan-Qiyu-Jiang fbshipit-source-id: b311419f4b1431a2cf8bb5322bd08d80e8a883c3
1 parent 426de65 commit 7c3b527

File tree

4 files changed

+305
-4
lines changed

4 files changed

+305
-4
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
model_config:
2+
uniter:
3+
heads:
4+
vqa2:
5+
type: mlp
6+
freeze: false
7+
lr_multiplier: 1.0
8+
in_dim: 768
9+
hidden_size: 1536
10+
num_labels: 3129
11+
pooler_name: bert_pooler
12+
text_embeddings:
13+
type: bert_embeddings
14+
image_embeddings:
15+
type: uniter_image_embeddings
16+
params:
17+
name: 'uniter_image_embeddings'
18+
encoder:
19+
type: transformer
20+
params:
21+
bert_model_name: bert-base-uncased
22+
hidden_size: 768
23+
num_hidden_layers: 12
24+
num_attention_heads: 12
25+
output_attentions: false
26+
output_hidden_states: false
27+
tasks:
28+
- vqa2

mmf/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from .pythia import Pythia
1616
from .top_down_bottom_up import TopDownBottomUp
1717
from .unimodal import UnimodalBase, UnimodalText, UnimodalModal
18+
from .uniter import UNITER
1819
from .vilbert import ViLBERT
1920
from .vilt import ViLT
2021
from .visual_bert import VisualBERT
2122

22-
2323
__all__ = [
2424
"TopDownBottomUp",
2525
"Pythia",
@@ -45,4 +45,5 @@
4545
"UnimodalText",
4646
"AlbefVitEncoder",
4747
"ViLT",
48+
"UNITER",
4849
]

mmf/models/uniter.py

Lines changed: 166 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
import logging
88
import random
99
from collections import MutableMapping, namedtuple
10+
from dataclasses import asdict, dataclass, field
1011
from typing import Any, Dict, List, Optional, Tuple, Union
1112

1213
import numpy as np
1314
import torch
1415
from mmf.common.registry import registry
16+
from mmf.models import BaseModel
1517
from mmf.modules.losses import MMFLoss
1618
from mmf.utils.general import retry_n
17-
from omegaconf import DictConfig, OmegaConf
19+
from omegaconf import MISSING, DictConfig, OmegaConf
1820
from torch import Tensor, nn
1921
from transformers.modeling_bert import BertConfig, BertEmbeddings, BertModel
2022

@@ -310,8 +312,8 @@ def __init__(
310312
for task in self.tasks:
311313
assert task in head_configs, (
312314
f"Task {task} is specified in your model configs"
313-
+ " but there is no head configured for the task."
314-
+ "Head configs can be added under model_config.heads"
315+
+ " but there is no head configured for the task. "
316+
+ "Head configs can be added under model_config.heads "
315317
+ "in your yaml configs. Either remove this task if UNITER"
316318
+ " is not meant to run on a dataset named {task}"
317319
+ " or add a head config."
@@ -603,3 +605,164 @@ def _remove_mismatched_captions(self, processed_sample_list: Dict[str, Tensor]):
603605
x = x[pos_pairs_mask]
604606
else:
605607
x = x[pos_pairs_mask, ::]
608+
609+
610+
@registry.register_model("uniter")
611+
class UNITER(BaseModel):
612+
"""Modification for Joint Vision-Language Encoding"""
613+
614+
@dataclass
615+
class Config:
616+
random_init: bool = False
617+
bert_model_name: str = "bert-base-uncased"
618+
img_dim: int = 2048
619+
hidden_size: int = 768
620+
hidden_dropout_prob: float = 0
621+
text_embeddings: Any = field(default_factory=lambda: {})
622+
encoder: Any = field(default_factory=lambda: {})
623+
heads: Any = MISSING
624+
losses: Any = field(default_factory=lambda: {})
625+
tasks: Any = MISSING
626+
do_pretraining: bool = False
627+
628+
def __init__(self, config):
629+
super().__init__(config)
630+
self.config = OmegaConf.create({**asdict(self.Config()), **config})
631+
self.do_pretraining = self.config.do_pretraining
632+
633+
@classmethod
634+
def config_path(cls):
635+
return "configs/models/uniter/defaults.yaml"
636+
637+
def build(self):
638+
configs = dict(**self.config)
639+
configs["head_configs"] = configs.pop("heads")
640+
configs["loss_configs"] = configs.pop("losses")
641+
params_keys = [
642+
"head_configs",
643+
"loss_configs",
644+
"tasks",
645+
"random_init",
646+
"bert_model_name",
647+
"img_dim",
648+
"hidden_size",
649+
"hidden_dropout_prob",
650+
"text_embeddings",
651+
"encoder",
652+
]
653+
if self.do_pretraining:
654+
# take value from config when the key exists,
655+
# otherwise use constructor defaults
656+
params_keys += ["mask_probability"]
657+
params = {key: configs[key] for key in params_keys if key in configs}
658+
self.uniter = UNITERForPretraining(**params)
659+
else:
660+
params = {key: configs[key] for key in params_keys if key in configs}
661+
self.uniter = UNITERForClassification(**params)
662+
663+
self.tasks = self.config.tasks
664+
if isinstance(self.tasks, str):
665+
self.tasks = self.tasks.split(",")
666+
667+
def init_losses(self):
668+
"""
669+
Defer loss management to submodels,
670+
do nothing when called by build_model.
671+
"""
672+
pass
673+
674+
def add_pos_feat(self, sample_list: Dict[str, Tensor]):
675+
assert "image_info_0" in sample_list
676+
assert "bbox" in sample_list["image_info_0"]
677+
678+
# (x1, y1, x2, y2), dim = (bs, num_feats, 4)
679+
bboxs = torch.tensor(sample_list["image_info_0"]["bbox"])[:, :, :4]
680+
norm_xy = torch.clone(bboxs)
681+
# if bboxs are not normalized, just do it here
682+
if norm_xy[0, 0, 0] < 1:
683+
img_h = (
684+
torch.tensor(sample_list["image_info_0"]["image_height"])
685+
.unsqueeze(1)
686+
.unsqueeze(1)
687+
) # (bs,)
688+
img_w = (
689+
torch.tensor(sample_list["image_info_0"]["image_width"])
690+
.unsqueeze(1)
691+
.unsqueeze(1)
692+
) # (bs,)
693+
max_image_size = torch.cat([img_w, img_h, img_w, img_h], dim=-1)
694+
max_image_size = max_image_size.to(norm_xy.device)
695+
norm_xy /= max_image_size
696+
697+
bbox_w = (norm_xy[:, :, 2] - norm_xy[:, :, 0]).unsqueeze(-1)
698+
bbox_h = (norm_xy[:, :, 3] - norm_xy[:, :, 1]).unsqueeze(-1)
699+
area = bbox_w * bbox_h
700+
# normalized (x1, y1, x2, y2, w, h, area)
701+
pos_feat = torch.cat([norm_xy, bbox_w, bbox_h, area], dim=-1).to(
702+
sample_list["image_feature_0"]
703+
)
704+
sample_list["img_pos_feat"] = pos_feat
705+
706+
def add_custom_params(self, sample_list: Dict[str, Tensor]) -> Dict[str, Tensor]:
707+
image_feat = sample_list["image_feat"] = sample_list["image_feature_0"]
708+
709+
image_info = getattr(sample_list, "image_info_0", {})
710+
image_dim = getattr(image_info, "max_features", None)
711+
sample_list["image_dim"] = image_dim
712+
713+
image_mask = torch.arange(image_feat.size(-2), device=image_feat.device).expand(
714+
image_feat.size()[:-1]
715+
)
716+
if len(image_dim.size()) < len(image_mask.size()):
717+
image_dim = image_dim.unsqueeze(-1)
718+
assert len(image_dim.size()) == len(image_mask.size())
719+
image_mask = image_mask < image_dim
720+
sample_list["image_mask"] = image_mask.long()
721+
722+
sample_list["attention_mask"] = torch.cat(
723+
(sample_list["input_mask"], sample_list["image_mask"]), dim=-1
724+
)
725+
task_index = torch.randint(len(self.tasks), (1,)).item()
726+
sample_list["task"] = self.tasks[task_index]
727+
sample_list["position_ids"] = torch.arange(
728+
0,
729+
sample_list["input_ids"].size(1),
730+
dtype=torch.long,
731+
device=image_feat.device,
732+
).unsqueeze(0)
733+
734+
self.add_pos_feat(sample_list)
735+
return sample_list
736+
737+
def forward(self, sample_list: Dict[str, Tensor]) -> Dict[str, Tensor]:
738+
sample_list = self.add_custom_params(sample_list)
739+
return self.uniter(sample_list)
740+
741+
def get_attention_mask(
742+
self,
743+
sample_list: Dict[str, Tensor],
744+
text_embedding: Tensor,
745+
image_embedding: Tensor,
746+
) -> Tensor:
747+
image_mask = getattr(sample_list, "image_mask", None)
748+
749+
if image_mask is not None and sample_list.input_mask is not None:
750+
attention_mask = torch.cat((sample_list.input_mask, image_mask), dim=-1)
751+
elif image_mask is not None:
752+
text_mask = torch.ones(
753+
text_embedding.size()[:-1],
754+
dtype=text_embedding.dtype,
755+
device=text_embedding.device,
756+
)
757+
attention_mask = torch.cat((image_mask, text_mask), dim=-1)
758+
elif sample_list.input_mask is not None:
759+
image_mask = torch.ones(
760+
image_embedding.size()[:-1],
761+
dtype=image_embedding.dtype,
762+
device=image_embedding.device,
763+
)
764+
attention_mask = torch.cat((image_mask, sample_list.input_mask), dim=-1)
765+
else:
766+
attention_mask = None
767+
768+
return attention_mask

tests/models/test_uniter.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import gc
23
import unittest
34

5+
import tests.test_utils as test_utils
46
import torch
57
from mmf.common.sample import SampleList
68
from mmf.models.uniter import (
@@ -9,6 +11,9 @@
911
UNITERImageEmbeddings,
1012
UNITERModelBase,
1113
)
14+
from mmf.utils.build import build_model
15+
from mmf.utils.configuration import Configuration
16+
from mmf.utils.env import setup_imports, teardown_imports
1217
from mmf.utils.general import get_current_device
1318
from omegaconf import OmegaConf
1419

@@ -166,3 +171,107 @@ def test_uniter_for_pretraining(self):
166171

167172
self.assertTrue("losses" in model_output)
168173
self.assertTrue(loss_name in model_output["losses"])
174+
175+
176+
class TestUniterModel(unittest.TestCase):
177+
def setUp(self):
178+
test_utils.setup_proxy()
179+
setup_imports()
180+
model_name = "uniter"
181+
args = test_utils.dummy_args(model=model_name, dataset="vqa2")
182+
configuration = Configuration(args)
183+
config = configuration.get_config()
184+
model_config = config.model_config[model_name]
185+
model_config.model = model_name
186+
model_config.losses = {"vqa2": "logit_bce"}
187+
model_config.do_pretraining = False
188+
model_config.tasks = "vqa2"
189+
classification_config_dict = {
190+
"do_pretraining": False,
191+
"tasks": "vqa2",
192+
"heads": {"vqa2": {"type": "mlp", "num_labels": 3129}},
193+
"losses": {"vqa2": "logit_bce"},
194+
}
195+
classification_config = OmegaConf.create(
196+
{**model_config, **classification_config_dict}
197+
)
198+
199+
pretraining_config_dict = {
200+
"do_pretraining": True,
201+
"tasks": "wra",
202+
"heads": {"wra": {"type": "wra"}},
203+
}
204+
pretraining_config = OmegaConf.create(
205+
{**model_config, **pretraining_config_dict}
206+
)
207+
208+
self.model_for_classification = build_model(classification_config)
209+
self.model_for_pretraining = build_model(pretraining_config)
210+
211+
def tearDown(self):
212+
teardown_imports()
213+
del self.model_for_classification
214+
del self.model_for_pretraining
215+
gc.collect()
216+
217+
def _get_sample_list(self):
218+
bs = 8
219+
num_feats = 100
220+
max_sentence_len = 25
221+
img_dim = 2048
222+
vqa_cls_dim = 3129
223+
input_ids = torch.ones((bs, max_sentence_len), dtype=torch.long)
224+
input_mask = torch.ones((bs, max_sentence_len), dtype=torch.long)
225+
img_feat = torch.rand((bs, num_feats, img_dim))
226+
227+
max_features = torch.ones((bs, num_feats)) * num_feats
228+
bbox = torch.randint(50, 200, (bs, num_feats, 4)).float()
229+
image_height = torch.randint(100, 300, (bs,))
230+
image_width = torch.randint(100, 300, (bs,))
231+
image_info = {
232+
"max_features": max_features,
233+
"bbox": bbox,
234+
"image_height": image_height,
235+
"image_width": image_width,
236+
}
237+
targets = torch.rand((bs, vqa_cls_dim))
238+
is_correct = torch.ones((bs,), dtype=torch.long)
239+
240+
sample_list = SampleList()
241+
sample_list.add_field("input_ids", input_ids)
242+
sample_list.add_field("image_feature_0", img_feat)
243+
sample_list.add_field("input_mask", input_mask)
244+
sample_list.add_field("image_info_0", image_info)
245+
sample_list.add_field("targets", targets)
246+
sample_list.add_field("is_correct", is_correct)
247+
sample_list = sample_list.to(get_current_device())
248+
return sample_list
249+
250+
def test_uniter_for_classification(self):
251+
self.model_for_classification.eval()
252+
self.model_for_classification = self.model_for_classification.to(
253+
get_current_device()
254+
)
255+
sample_list = self._get_sample_list()
256+
257+
sample_list.dataset_name = "vqa2"
258+
sample_list.dataset_type = "test"
259+
with torch.no_grad():
260+
model_output = self.model_for_classification(sample_list)
261+
262+
self.assertTrue("losses" in model_output)
263+
self.assertTrue("test/vqa2/logit_bce" in model_output["losses"])
264+
265+
def test_uniter_for_pretraining(self):
266+
self.model_for_pretraining.eval()
267+
self.model_for_pretraining = self.model_for_pretraining.to(get_current_device())
268+
sample_list = self._get_sample_list()
269+
sample_list["tasks"] = "wra"
270+
271+
sample_list.dataset_name = "vqa2"
272+
sample_list.dataset_type = "test"
273+
with torch.no_grad():
274+
model_output = self.model_for_pretraining(sample_list)
275+
276+
self.assertTrue("losses" in model_output)
277+
self.assertTrue("wra_loss" in model_output["losses"])

0 commit comments

Comments
 (0)