|
7 | 7 | import logging |
8 | 8 | import random |
9 | 9 | from collections import MutableMapping, namedtuple |
| 10 | +from dataclasses import asdict, dataclass, field |
10 | 11 | from typing import Any, Dict, List, Optional, Tuple, Union |
11 | 12 |
|
12 | 13 | import numpy as np |
13 | 14 | import torch |
14 | 15 | from mmf.common.registry import registry |
| 16 | +from mmf.models import BaseModel |
15 | 17 | from mmf.modules.losses import MMFLoss |
16 | 18 | from mmf.utils.general import retry_n |
17 | | -from omegaconf import DictConfig, OmegaConf |
| 19 | +from omegaconf import MISSING, DictConfig, OmegaConf |
18 | 20 | from torch import Tensor, nn |
19 | 21 | from transformers.modeling_bert import BertConfig, BertEmbeddings, BertModel |
20 | 22 |
|
@@ -310,8 +312,8 @@ def __init__( |
310 | 312 | for task in self.tasks: |
311 | 313 | assert task in head_configs, ( |
312 | 314 | f"Task {task} is specified in your model configs" |
313 | | - + " but there is no head configured for the task." |
314 | | - + "Head configs can be added under model_config.heads" |
| 315 | + + " but there is no head configured for the task. " |
| 316 | + + "Head configs can be added under model_config.heads " |
315 | 317 | + "in your yaml configs. Either remove this task if UNITER" |
316 | 318 | + " is not meant to run on a dataset named {task}" |
317 | 319 | + " or add a head config." |
@@ -603,3 +605,164 @@ def _remove_mismatched_captions(self, processed_sample_list: Dict[str, Tensor]): |
603 | 605 | x = x[pos_pairs_mask] |
604 | 606 | else: |
605 | 607 | x = x[pos_pairs_mask, ::] |
| 608 | + |
| 609 | + |
@registry.register_model("uniter")
class UNITER(BaseModel):
    """Modification for Joint Vision-Language Encoding

    MMF wrapper that builds either ``UNITERForPretraining`` or
    ``UNITERForClassification`` (selected by ``config.do_pretraining``)
    and prepares each sample list (attention masks, position ids,
    normalized bbox position features) before delegating the forward
    pass to the wrapped model.
    """

    @dataclass
    class Config:
        # Defaults for the model; any key may be overridden by the yaml
        # config passed to __init__.  ``heads`` and ``tasks`` are MISSING
        # (required) and must be supplied by the user config.
        random_init: bool = False
        bert_model_name: str = "bert-base-uncased"
        img_dim: int = 2048
        hidden_size: int = 768
        hidden_dropout_prob: float = 0
        text_embeddings: Any = field(default_factory=lambda: {})
        encoder: Any = field(default_factory=lambda: {})
        heads: Any = MISSING
        losses: Any = field(default_factory=lambda: {})
        tasks: Any = MISSING
        do_pretraining: bool = False

    def __init__(self, config):
        super().__init__(config)
        # Overlay the user-supplied config on the dataclass defaults so
        # that every key declared in Config exists on self.config.
        self.config = OmegaConf.create({**asdict(self.Config()), **config})
        self.do_pretraining = self.config.do_pretraining

    @classmethod
    def config_path(cls):
        # Default yaml config shipped with MMF for this model.
        return "configs/models/uniter/defaults.yaml"

    def build(self):
        """Instantiate the wrapped UNITER submodel from ``self.config``."""
        configs = dict(**self.config)
        # Rename yaml keys to the constructor kwarg names of the submodels.
        configs["head_configs"] = configs.pop("heads")
        configs["loss_configs"] = configs.pop("losses")
        params_keys = [
            "head_configs",
            "loss_configs",
            "tasks",
            "random_init",
            "bert_model_name",
            "img_dim",
            "hidden_size",
            "hidden_dropout_prob",
            "text_embeddings",
            "encoder",
        ]
        if self.do_pretraining:
            # take value from config when the key exists,
            # otherwise use constructor defaults
            params_keys += ["mask_probability"]
            params = {key: configs[key] for key in params_keys if key in configs}
            self.uniter = UNITERForPretraining(**params)
        else:
            params = {key: configs[key] for key in params_keys if key in configs}
            self.uniter = UNITERForClassification(**params)

        # ``tasks`` may be given as a list or a comma-separated string.
        self.tasks = self.config.tasks
        if isinstance(self.tasks, str):
            self.tasks = self.tasks.split(",")

    def init_losses(self):
        """
        Defer loss management to submodels,
        do nothing when called by build_model.
        """
        pass

    def add_pos_feat(self, sample_list: Dict[str, Tensor]):
        """Attach normalized bbox position features as ``img_pos_feat``.

        Produces a 7-dim feature per region:
        (x1, y1, x2, y2, w, h, area), with coordinates normalized by the
        image size and cast to the dtype/device of ``image_feature_0``.
        """
        assert "image_info_0" in sample_list
        assert "bbox" in sample_list["image_info_0"]

        # (x1, y1, x2, y2), dim = (bs, num_feats, 4)
        bboxs = torch.tensor(sample_list["image_info_0"]["bbox"])[:, :, :4]
        norm_xy = torch.clone(bboxs)
        # if bboxs are not normalized, just do it here
        # NOTE(review): this heuristic inspects only the first box's x1;
        # a pixel-space x1 of 0 also satisfies "< 1", and already-normalized
        # coords would be divided again — confirm against the feature
        # extractor that this check selects the intended branch.
        if norm_xy[0, 0, 0] < 1:
            img_h = (
                torch.tensor(sample_list["image_info_0"]["image_height"])
                .unsqueeze(1)
                .unsqueeze(1)
            )  # (bs,)
            img_w = (
                torch.tensor(sample_list["image_info_0"]["image_width"])
                .unsqueeze(1)
                .unsqueeze(1)
            )  # (bs,)
            # Per-image divisor (w, h, w, h), broadcast over all boxes.
            max_image_size = torch.cat([img_w, img_h, img_w, img_h], dim=-1)
            max_image_size = max_image_size.to(norm_xy.device)
            norm_xy /= max_image_size

        bbox_w = (norm_xy[:, :, 2] - norm_xy[:, :, 0]).unsqueeze(-1)
        bbox_h = (norm_xy[:, :, 3] - norm_xy[:, :, 1]).unsqueeze(-1)
        area = bbox_w * bbox_h
        # normalized (x1, y1, x2, y2, w, h, area)
        pos_feat = torch.cat([norm_xy, bbox_w, bbox_h, area], dim=-1).to(
            sample_list["image_feature_0"]
        )
        sample_list["img_pos_feat"] = pos_feat

    def add_custom_params(self, sample_list: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Populate masks, position ids, task name and bbox features.

        Mutates ``sample_list`` in place and also returns it.
        """
        image_feat = sample_list["image_feat"] = sample_list["image_feature_0"]

        # NOTE(review): getattr on sample_list/image_info relies on MMF's
        # SampleList attribute-access behaviour, not plain dicts — verify.
        image_info = getattr(sample_list, "image_info_0", {})
        image_dim = getattr(image_info, "max_features", None)
        sample_list["image_dim"] = image_dim

        # Mask padded image regions: position index < number of valid
        # features for that sample.
        image_mask = torch.arange(image_feat.size(-2), device=image_feat.device).expand(
            image_feat.size()[:-1]
        )
        # NOTE(review): image_dim is assumed to be a tensor here (the code
        # below crashes if max_features was absent and image_dim is None) —
        # presumably upstream extractors always provide it; confirm.
        if len(image_dim.size()) < len(image_mask.size()):
            image_dim = image_dim.unsqueeze(-1)
        assert len(image_dim.size()) == len(image_mask.size())
        image_mask = image_mask < image_dim
        sample_list["image_mask"] = image_mask.long()

        # Joint attention mask over [text; image] positions.
        sample_list["attention_mask"] = torch.cat(
            (sample_list["input_mask"], sample_list["image_mask"]), dim=-1
        )
        # Pick one of the configured tasks uniformly at random per batch.
        task_index = torch.randint(len(self.tasks), (1,)).item()
        sample_list["task"] = self.tasks[task_index]
        sample_list["position_ids"] = torch.arange(
            0,
            sample_list["input_ids"].size(1),
            dtype=torch.long,
            device=image_feat.device,
        ).unsqueeze(0)

        self.add_pos_feat(sample_list)
        return sample_list

    def forward(self, sample_list: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Prepare the sample list, then run the wrapped UNITER submodel."""
        sample_list = self.add_custom_params(sample_list)
        return self.uniter(sample_list)

    def get_attention_mask(
        self,
        sample_list: Dict[str, Tensor],
        text_embedding: Tensor,
        image_embedding: Tensor,
    ) -> Tensor:
        """Build a combined attention mask, synthesizing an all-ones mask
        for whichever modality lacks one; returns None if neither exists.

        NOTE(review): the concatenation order differs between branches —
        (text, image) when both masks exist, (image, text) otherwise —
        confirm this matches the embedding order used downstream.
        """
        image_mask = getattr(sample_list, "image_mask", None)

        if image_mask is not None and sample_list.input_mask is not None:
            attention_mask = torch.cat((sample_list.input_mask, image_mask), dim=-1)
        elif image_mask is not None:
            # No text mask provided: treat every text position as valid.
            text_mask = torch.ones(
                text_embedding.size()[:-1],
                dtype=text_embedding.dtype,
                device=text_embedding.device,
            )
            attention_mask = torch.cat((image_mask, text_mask), dim=-1)
        elif sample_list.input_mask is not None:
            # No image mask provided: treat every image position as valid.
            image_mask = torch.ones(
                image_embedding.size()[:-1],
                dtype=image_embedding.dtype,
                device=image_embedding.device,
            )
            attention_mask = torch.cat((image_mask, sample_list.input_mask), dim=-1)
        else:
            attention_mask = None

        return attention_mask
0 commit comments