|
6 | 6 |
|
7 | 7 | import logging |
8 | 8 | from collections import namedtuple |
9 | | -from typing import Optional, Tuple |
| 9 | +from dataclasses import asdict |
| 10 | +from typing import Dict, Optional, Tuple |
10 | 11 |
|
11 | 12 | import torch |
| 13 | +from mmf.common.sample import SampleList |
| 14 | +from mmf.models.transformers.heads.contrastive import ThreeWayContrastive |
| 15 | +from mmf.models.transformers.heads.mlm import MLM |
| 16 | +from mmf.models.transformers.heads.mlp import MLP |
| 17 | +from mmf.utils.general import retry_n |
12 | 18 | from torch import Tensor, nn |
13 | 19 | from transformers.modeling_bert import ( |
14 | 20 | BertConfig, |
|
19 | 25 |
|
20 | 26 | logger = logging.getLogger(__name__) |
21 | 27 |
|
| 28 | +NUM_RETRIES = 6 |
| 29 | + |
22 | 30 |
|
23 | 31 | class VinVLBase(BertPreTrainedModel): |
24 | 32 | """VinVL Bert Encoder for image features |
@@ -99,3 +107,274 @@ def forward( |
99 | 107 | ) |
100 | 108 | layers = namedtuple("TransformerOutput", ["last_hidden_state", "hidden_layers"]) |
101 | 109 | return layers(encoder_outputs[0], encoder_outputs[1]) |
| 110 | + |
| 111 | + |
def build_vinvl_base(
    bert_model_name: str = "bert-base-uncased",
    img_feature_dim: int = 2054,
    use_img_layernorm: bool = True,
    img_layer_norm_eps: float = 1e-12,
    random_init: bool = True,
) -> VinVLBase:
    """Construct a ``VinVLBase`` encoder from a huggingface bert config.

    Args:
        bert_model_name: Name of the base bert model whose config
            (and optionally weights) seed the encoder.
        img_feature_dim: Size of the VinVL image feature inputs.
        use_img_layernorm: Whether to layernorm the image encoding.
        img_layer_norm_eps: Epsilon for the image layernorm.
        random_init: If True, build the encoder with random weights;
            otherwise load pretrained weights for ``bert_model_name``.

    Returns:
        A configured ``VinVLBase`` encoder.
    """
    # Downloads can fail transiently, so wrap them in retry_n.
    config = retry_n(NUM_RETRIES, BertConfig.from_pretrained, bert_model_name)

    # Augment the hf BertConfig with the VinVL image-feature options.
    config.img_feature_dim = img_feature_dim
    config.use_img_layernorm = use_img_layernorm
    config.img_layer_norm_eps = img_layer_norm_eps

    if not random_init:
        return retry_n(
            NUM_RETRIES,
            VinVLBase.from_pretrained,
            bert_model_name,
            config=config,
        )
    return VinVLBase(config)
| 139 | + |
| 140 | + |
class VinVLForClassification(nn.Module):
    """VINVL wrapper for classification"""

    def __init__(
        self,
        mlp_config: Optional[Dict] = None,
        loss_config: Optional[Dict] = None,
        random_init: bool = False,
        bert_model_name: str = "bert-base-uncased",
        img_feature_dim: int = 2054,
        use_img_layernorm: bool = True,
        img_layer_norm_eps: float = 1e-12,
        *args,
        **kwargs,
    ):
        """VinVL model constructor for classification.
        MLP head is configurable through Dict type.
        Consult the MLP head class for the config options.

        Args:
            mlp_config (Optional[Dict], optional):
                Classifier MLP head config.
                Defaults to {"num_layers": 0}.
            loss_config (Optional[Dict], optional):
                nn.CrossEntropyLoss params dict.
                Defaults to {}.
            random_init (bool, optional):
                Flag to load VinVL bert weights from random_init.
                Defaults to False.
            bert_model_name (str, optional):
                Name for base bert model.
                Used for VinVL base configs and weights.
                Defaults to "bert-base-uncased".
            img_feature_dim (int, optional):
                The size of the VinVL image feature inputs.
                Defaults to 2054.
            use_img_layernorm (bool, optional):
                Flag to use layernorm on image encoding.
                Defaults to True.
            img_layer_norm_eps (float, optional):
                Image layernorm epsilon. Defaults to 1e-12.
        """
        super().__init__()
        # Apply documented defaults for the optional config dicts.
        head_config = {"num_layers": 0} if mlp_config is None else mlp_config
        ce_params = {} if loss_config is None else loss_config

        self.bert = build_vinvl_base(
            bert_model_name=bert_model_name,
            img_feature_dim=img_feature_dim,
            use_img_layernorm=use_img_layernorm,
            img_layer_norm_eps=img_layer_norm_eps,
            random_init=random_init,
        )
        self.classifier = MLP(config=head_config)
        self.ce_loss = nn.CrossEntropyLoss(**ce_params)

    def forward(
        self,
        input_ids: Tensor,
        token_type_ids: Tensor,
        attention_mask: Tensor,
        img_feats: Tensor,
        position_ids: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Encode text + image features, classify, and (optionally) score loss.

        Returns a dict with "scores"; when `labels` is given it also
        contains "losses" with the cross-entropy loss under key "ce".
        """
        encoder_out = self.bert(
            input_ids,
            img_feats=img_feats,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        scores = self.classifier(encoder_out.last_hidden_state)["scores"]
        output = {"scores": scores}

        if labels is not None:
            # Flatten to (N, num_classes) vs (N,) for CrossEntropyLoss.
            loss = self.ce_loss(scores.view(-1, scores.size(1)), labels.view(-1))
            output["losses"] = {"ce": loss}
        return output
| 222 | + |
| 223 | + |
class VinVLForPretraining(nn.Module):
    """VINVL wrapper for pretraining
    MLM loss is described in https://arxiv.org/pdf/2004.06165.pdf
    Contrastive loss is an itm loss to guess,
    0 for a match,
    1 for a corrupt caption,
    2 for corrupt image labels
    VinVL trains with object detection labels concatenated with the input text.
    """

    def __init__(
        self,
        mlm_config: Optional[MLM.Config] = None,
        contrast_config: Optional[ThreeWayContrastive.Config] = None,
        random_init: bool = False,
        bert_model_name: str = "bert-base-uncased",
        img_feature_dim: int = 2054,
        use_img_layernorm: bool = True,
        img_layer_norm_eps: float = 1e-12,
        *args,
        **kwargs,
    ):
        """VinVL model constructor for pretraining.
        MLM and Contrastive Loss heads are configurable through Dict types.
        Consult MLM and MLP head classes for their config options.

        Args:
            mlm_config (Optional[MLM.Config], optional):
                Config object for MLM head.
                Defaults to MLM.Config which uses the default MLM configs.
            contrast_config (Optional[ThreeWayContrastive.Config], optional):
                Config object for the 3-way contrastive head.
                Defaults to ThreeWayContrastive.Config which uses a MLP with 3 classes
            random_init (bool, optional):
                Flag to load VinVL bert weights from random_init.
                Defaults to False.
            bert_model_name (str, optional):
                Name for base bert model.
                Used for VinVL base configs and weights.
                Defaults to "bert-base-uncased".
            img_feature_dim (int, optional):
                The size of the VinVL image feature inputs.
                Defaults to 2054.
            use_img_layernorm (bool, optional):
                Flag to use layernorm on image encoding.
                Defaults to True.
            img_layer_norm_eps (float, optional):
                Image layernorm epsilon. Defaults to 1e-12.
        """
        super().__init__()
        # Default head configs come from the dataclass defaults, flattened
        # to plain dicts for the head constructors.
        head_mlm_config = asdict(MLM.Config()) if mlm_config is None else mlm_config
        head_contrast_config = (
            asdict(ThreeWayContrastive.Config())
            if contrast_config is None
            else contrast_config
        )

        self.bert = build_vinvl_base(
            bert_model_name=bert_model_name,
            img_feature_dim=img_feature_dim,
            use_img_layernorm=use_img_layernorm,
            img_layer_norm_eps=img_layer_norm_eps,
            random_init=random_init,
        )
        self.mlm_head = MLM(config=head_mlm_config)
        self.ce_loss = nn.CrossEntropyLoss()
        self.contrast_head = ThreeWayContrastive(head_contrast_config)

    def mlm_forward(
        self,
        input_ids_masked: Tensor,
        lm_label_ids: Tensor,
        token_type_ids: Tensor,
        attention_mask: Tensor,
        img_feats: Tensor,
        position_ids: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Run the masked-LM objective on the masked caption + image features."""
        encoded = self.bert(
            input_ids_masked,
            img_feats=img_feats,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        # Image regions carry no MLM targets; fill with -1 so they are
        # ignored by the head's loss.
        image_labels = torch.full(
            img_feats.shape[:2],
            fill_value=-1,
            dtype=torch.long,
            device=lm_label_ids.device,
        )
        mlm_labels = {
            "text": lm_label_ids,
            "image": image_labels,
            "combined_labels": torch.cat([lm_label_ids, image_labels], dim=-1),
        }

        sample_list = SampleList({"mlm_labels": mlm_labels})
        return self.mlm_head(encoded, processed_sample_list=sample_list)["losses"]

    def contrastive_forward(
        self,
        input_ids: Tensor,
        token_type_ids: Tensor,
        attention_mask: Tensor,
        img_feats: Tensor,
        contrastive_labels: Tensor,
        position_ids: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Run the 3-way contrastive (itm-style) objective."""
        encoded = self.bert(
            input_ids,
            img_feats=img_feats,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        # contrastive 3-way loss has 3 classes,
        # 0 for a match, 1, 2 for a corrupt caption/image
        # labels respectively
        sample_list = SampleList({"contrastive_labels": contrastive_labels})
        return self.contrast_head(encoded, sample_list)["losses"]

    def forward(
        self,
        input_ids_masked: Tensor,
        input_ids_corrupt: Tensor,
        lm_label_ids: Tensor,
        contrastive_labels: Tensor,
        token_type_ids: Tensor,
        attention_mask: Tensor,
        token_type_ids_corrupt: Tensor,
        attention_mask_corrupt: Tensor,
        img_feats: Tensor,
        position_ids: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Compute both pretraining objectives and merge their loss dicts."""
        losses = dict(
            self.mlm_forward(
                input_ids_masked,
                lm_label_ids,
                token_type_ids,
                attention_mask,
                img_feats,
                position_ids,
            )
        )
        losses.update(
            self.contrastive_forward(
                input_ids_corrupt,
                token_type_ids_corrupt,
                attention_mask_corrupt,
                img_feats,
                contrastive_labels,
                position_ids,
            )
        )
        return {"losses": losses}
0 commit comments