Skip to content

Commit 9564983

Browse files
[feat] Add UNITER pretraining heads (#1126)
Summary: Pull Request resolved: #1126 Add MRC, MRFR, WRA heads for UNITER pretraining. MRC = Masked Region Classification. MRFR = Masked Region Feature Regression. WRA = Word Region Alignment. Heads forward return a dict with `losses`. These heads can be used as pretraining tasks for other VL models. Details at https://arxiv.org/abs/1909.11740 Test Plan: **Unit tests** Test direct instantiation and forward pass for each head. Instantiation through build() and configs is tested in unit tests in later diffs by the models that use these heads. Tested as part of UNITER pretraining on masked COCO Reviewed By: ebsmothers Differential Revision: D31768455 Pulled By: Ryan-Qiyu-Jiang fbshipit-source-id: 9b48f81c472cd1859f32bc813484296208e206f5
1 parent 6fba5f6 commit 9564983

File tree

6 files changed

+495
-0
lines changed

6 files changed

+495
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
3+
# Initial version was taken from https://github.com/ChenRocks/UNITER/
4+
# and adapted for MMF.
5+
6+
from typing import Dict
7+
8+
import torch
9+
import torch.nn.functional as F
10+
from mmf.common.registry import registry
11+
from mmf.models.transformers.heads.utils import compute_masked_hidden
12+
from torch import Tensor, nn
13+
14+
15+
@registry.register_transformer_head("mrc")
class MRC(nn.Module):
    """
    Masked Region Classification head from UNITER
    (https://arxiv.org/abs/1909.11740).

    Classifies the object class of masked image regions from the transformer
    sequence output. Supervision is either KL divergence against soft region
    labels (``use_kl=True``) or cross entropy against the argmax hard label.
    Forward returns ``{"losses": {loss_name: loss}}``.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        loss_name: str = "mrc_loss",
        ignore_index: int = -1,
        mrc_label_key: str = "region_class",
        mrc_mask_key: str = "image_region_mask",
        label_dim: int = 1601,
        eps: float = 1e-12,
        use_kl: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.loss_name = loss_name
        self.ignore_index = ignore_index
        self.mrc_label_key = mrc_label_key
        self.mrc_mask_key = mrc_mask_key
        self.use_kl = use_kl

        # Head modules: two-layer MLP with GELU + LayerNorm projecting the
        # hidden state to the region-class logit space.
        self.region_classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size, eps=eps),
            nn.Linear(hidden_size, label_dim),
        )

    def forward(
        self,
        sequence_output: Tensor,
        processed_sample_list: Dict[str, Dict[str, Tensor]],
    ) -> Dict[str, Dict[str, Tensor]]:
        # Both the region labels and the region mask must be supplied by the
        # sample-list preprocessing step.
        for key in (self.mrc_label_key, self.mrc_mask_key):
            assert (
                key in processed_sample_list
                and processed_sample_list[key] is not None
            ), (
                f"MRC pretraining requires {key} to be in sample "
                + "list with value not None."
            )

        # (bs*num_feat, label_dim) Look at unit test for example usage!
        region_labels = processed_sample_list[self.mrc_label_key]
        # (bs, num_feat)
        image_region_masks = processed_sample_list[self.mrc_mask_key]

        masked_hidden = compute_masked_hidden(sequence_output, image_region_masks)
        region_scores = self.region_classifier(masked_hidden)

        if self.use_kl:
            # Soft-label supervision: KL(log-softmax(pred) || soft labels).
            log_probs = F.log_softmax(region_scores, dim=-1)
            mrc_loss = F.kl_div(log_probs, region_labels, reduction="batchmean")
        else:
            # background class should not be the target
            hard_targets = torch.max(region_labels[:, 1:], dim=-1)[1] + 1
            mrc_loss = F.cross_entropy(
                region_scores,
                hard_targets,
                ignore_index=self.ignore_index,
                reduction="mean",
            )

        return {"losses": {self.loss_name: mrc_loss}}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
3+
# Initial version was taken from https://github.com/ChenRocks/UNITER/
4+
# and adapted for MMF.
5+
6+
from typing import Dict
7+
8+
import torch
9+
import torch.nn.functional as F
10+
from mmf.common.registry import registry
11+
from mmf.models.transformers.heads.utils import compute_masked_hidden
12+
from torch import Tensor, nn
13+
14+
15+
@registry.register_transformer_head("mrfr")
class MRFR(nn.Module):
    """
    Masked Region Feature Regression transformer head,
    From uniter paper https://arxiv.org/pdf/1909.11740.pdf
    For an example usage take a look at the unit test.

    Regresses the original visual features of masked image regions from the
    transformer sequence output, tying the output projection to the image
    embedding weight. Forward returns ``{"losses": {loss_name: loss}}``.
    """

    def __init__(
        self,
        img_embedding_weight: nn.Parameter,
        hidden_size: int = 768,
        loss_name: str = "mrfr_loss",
        mrfr_target_key: str = "mrfr_region_target",
        mrfr_mask_key: str = "mrfr_region_mask",
        img_dim: int = 2048,
        eps: float = 1e-12,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.loss_name = loss_name
        self.mrfr_target_key = mrfr_target_key
        self.mrfr_mask_key = mrfr_mask_key

        # Head modules: the output projection weight is shared (tied) with
        # the image embedding, so its shape must match exactly.
        expected_shape = (hidden_size, img_dim)
        assert (
            img_embedding_weight is not None
            and tuple(img_embedding_weight.shape) == expected_shape
        ), (
            "MRFR head requires 'img_embedding_weight' with shape "
            + f"({hidden_size}, {img_dim})."
        )

        self.linear_proj_weight = img_embedding_weight
        self.linear_proj_bias = nn.Parameter(torch.zeros(img_dim))

        self.feat_regress = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size, eps=eps),
        )

    def forward(
        self,
        sequence_output: Tensor,
        processed_sample_list: Dict[str, Dict[str, Tensor]],
    ) -> Dict[str, Dict[str, Tensor]]:
        # Both the regression targets and the region mask must be supplied
        # by the sample-list preprocessing step.
        for key in (self.mrfr_target_key, self.mrfr_mask_key):
            assert (
                key in processed_sample_list
                and processed_sample_list[key] is not None
            ), (
                f"MRFR pretraining requires {key} to be in sample "
                + "list with value not None."
            )

        # (bs*num_feat, img_dim) Look at unit test for example usage!
        feat_targets = processed_sample_list[self.mrfr_target_key]
        # (bs, num_feat)
        image_region_masks = processed_sample_list[self.mrfr_mask_key]

        masked_hidden = compute_masked_hidden(sequence_output, image_region_masks)
        regressed = self.feat_regress(masked_hidden)
        # Project back to image-feature space using the transposed tied weight.
        predicted_feat = F.linear(
            regressed, self.linear_proj_weight.t(), self.linear_proj_bias
        )
        mrfr_loss = F.mse_loss(predicted_feat, feat_targets, reduction="mean")

        return {"losses": {self.loss_name: mrfr_loss}}

mmf/models/transformers/heads/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,16 @@ def _process_head_output(
164164
)
165165
output = self.losses[loss_name](sample_list, {"scores": logits})
166166
return {"losses": output, "scores": logits}
167+
168+
169+
def compute_masked_hidden(hidden: Tensor, mask: Tensor) -> Tensor:
    """Get only the masked region.

    hidden: tensor, dim (bs, num_feat, feat_dim)
    mask: bool tensor, dim (bs, num_feat)
    Returns a tensor of dim (bs * num_feat_unmasked, feat_dim),
    containing the features in hidden that are True in the mask tensor.
    """
    feat_dim = hidden.size(-1)
    # Broadcast the (bs, num_feat) mask over the feature dimension, then
    # gather the selected values and restore the (rows, feat_dim) layout.
    expanded_mask = mask.unsqueeze(-1).expand_as(hidden)
    return hidden.masked_select(expanded_mask).view(-1, feat_dim)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
3+
# Initial version was taken from https://github.com/ChenRocks/UNITER/
4+
# and adapted for MMF.
5+
6+
from typing import Dict
7+
8+
from mmf.common.registry import registry
9+
from mmf.modules.ot import optimal_transport_dist
10+
from torch import Tensor, nn
11+
12+
13+
@registry.register_transformer_head("wra")
class WRA(nn.Module):
    """
    Word Region Alignment from UNITER.
    Optimal Transport (OT) distance between text and image
    features is used to optimize for WRA.
    OT transport plan (T) is approximated through IPOT.

    Forward returns ``{"losses": {loss_name: loss}}``.
    """

    def __init__(
        self,
        loss_name: str = "wra_loss",
        ot_inputs_key: str = "wra_info",
        wra_label_key: str = "is_correct",
        *args,
        **kwargs,
    ):
        super().__init__()
        self.loss_name = loss_name
        self.ot_inputs_key = ot_inputs_key
        self.wra_label_key = wra_label_key

    def forward(
        self,
        sequence_output: Tensor,
        processed_sample_list: Dict[str, Dict[str, Tensor]],
    ) -> Dict[str, Dict[str, Tensor]]:
        assert (
            self.ot_inputs_key in processed_sample_list
            and processed_sample_list[self.ot_inputs_key] is not None
        ), (
            f"WRA pretraining requires {self.ot_inputs_key} to be in sample "
            + "list with value not None."
        )
        ot_inputs = processed_sample_list[self.ot_inputs_key]

        assert (
            ot_inputs.get("txt_pad") is not None
            and ot_inputs.get("img_pad") is not None
        ), (
            "WRA pretraining requires 'txt_pad', and 'img_pad' to be in "
            + f"'processed_sample_list[{self.ot_inputs_key}]' with"
            + " values not None."
        )
        assert processed_sample_list.get(self.wra_label_key) is not None, (
            f"WRA pretraining requires {self.wra_label_key} to be in sample "
            + "list with value not None."
        )

        # Text tokens occupy the first `num_text` positions of the joint
        # sequence; image regions follow immediately after.
        num_text = processed_sample_list["input_ids"].size(1)
        num_img = processed_sample_list["image_feat"].size(1)
        txt_emb = sequence_output[:, :num_text, :]
        img_emb = sequence_output[:, num_text : num_text + num_img, :]

        txt_pad = ot_inputs["txt_pad"].bool()
        img_pad = ot_inputs["img_pad"].bool()
        itm_labels = processed_sample_list[self.wra_label_key]
        # NOTE: run in fp32 for stability
        ot_dist = optimal_transport_dist(
            txt_emb.float(), img_emb.float(), txt_pad, img_pad
        ).to(txt_emb)
        # Minimize OT distance for matched pairs, maximize it for mismatched.
        ot_pos = ot_dist.masked_select(itm_labels == 1)
        ot_neg = ot_dist.masked_select(itm_labels == 0)
        ot_loss = (ot_pos.sum() - ot_neg.sum()) / (ot_pos.size(0) + ot_neg.size(0))

        return {"losses": {self.loss_name: ot_loss}}

mmf/modules/ot.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
3+
"""
4+
Initial version was taken from https://github.com/ChenRocks/UNITER/
5+
Licensed under the MIT license.
6+
7+
Wasserstein Distance (Optimal Transport)
8+
"""
9+
10+
import torch
11+
from torch import Tensor
12+
from torch.nn import functional as F
13+
14+
15+
def cost_matrix_cosine(x: Tensor, y: Tensor, eps: float = 1e-5) -> Tensor:
    """Compute cosine distance across every pairs of x, y (batched)
    [B, L_x, D] [B, L_y, D] -> [B, Lx, Ly]"""
    assert x.dim() == y.dim()
    assert x.size(0) == y.size(0)
    assert x.size(2) == y.size(2)
    # Unit-normalize along the feature axis; a batched matmul of the
    # normalized tensors then yields cosine similarity for every pair.
    x_unit = F.normalize(x, p=2, dim=-1, eps=eps)
    y_unit = F.normalize(y, p=2, dim=-1, eps=eps)
    similarity = x_unit.matmul(y_unit.transpose(1, 2))
    # Cosine distance = 1 - cosine similarity.
    return 1 - similarity
26+
27+
28+
def trace(x: Tensor) -> Tensor:
    """Compute trace of input tensor (batched): [B, N, N] -> [B]."""
    b, m, n = x.size()
    assert m == n
    # Sum the main diagonal of each matrix in the batch.
    return torch.diagonal(x, dim1=-2, dim2=-1).sum(dim=-1)
35+
36+
37+
@torch.no_grad()
def ipot(
    C: Tensor,
    x_len: Tensor,
    x_pad: Tensor,
    y_len: Tensor,
    y_pad: Tensor,
    joint_pad: Tensor,
    beta: float,
    iteration: int,
    k: int,
) -> Tensor:
    """Approximate the optimal transport plan with IPOT-style iterations.

    [B, M, N], [B], [B, M], [B], [B, N], [B, M, N]

    C: cost matrix; x_len / y_len: effective (unpadded) lengths per batch
    element; x_pad / y_pad / joint_pad: True at padded positions.
    Returns the transport plan T of shape [B, N, M]. Runs under no_grad,
    so callers treat the plan as a constant.
    """
    b, m, n = C.size()
    # sigma starts as the uniform marginal over x positions (1 / x_len).
    sigma = torch.ones(b, m, dtype=C.dtype, device=C.device) / x_len.unsqueeze(1)
    T = torch.ones(b, n, m, dtype=C.dtype, device=C.device)
    # Kernel of the transposed cost matrix; beta controls the sharpness.
    A = torch.exp(-C.transpose(1, 2) / beta)

    # mask padded positions
    sigma.masked_fill_(x_pad, 0)
    joint_pad = joint_pad.transpose(1, 2)
    T.masked_fill_(joint_pad, 0)
    A.masked_fill_(joint_pad, 0)

    # broadcastable lengths
    x_len = x_len.unsqueeze(1).unsqueeze(2)
    y_len = y_len.unsqueeze(1).unsqueeze(2)

    # mask to zero out padding in delta and sigma
    # (the large additive constant drives the 1/(...) updates toward 0
    # at padded slots)
    x_mask = (x_pad.to(C.dtype) * 1e4).unsqueeze(1)
    y_mask = (y_pad.to(C.dtype) * 1e4).unsqueeze(1)

    for _ in range(iteration):
        Q = A * T  # bs * n * m
        sigma = sigma.view(b, m, 1)
        # k inner scaling updates alternating between the two marginals.
        for _ in range(k):
            delta = 1 / (y_len * Q.matmul(sigma).view(b, 1, n) + y_mask)
            sigma = 1 / (x_len * delta.matmul(Q) + x_mask)
        # Rescale the kernel by both scaling vectors to form the new plan.
        T = delta.view(b, n, 1) * Q * sigma
        T.masked_fill_(joint_pad, 0)
    return T
78+
79+
80+
def optimal_transport_dist(
    txt_emb: Tensor,
    img_emb: Tensor,
    txt_pad: Tensor,
    img_pad: Tensor,
    beta: float = 0.5,
    iteration: int = 50,
    k: int = 1,
) -> Tensor:
    """Optimal transport distance between text and image embeddings.

    [B, M, D], [B, N, D], [B, M], [B, N]
    txt_pad / img_pad are True at padded positions. Returns a [B] tensor.
    """
    cost = cost_matrix_cosine(txt_emb, img_emb)
    # mask the padded inputs: any pair touching padding contributes no cost
    joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
    cost.masked_fill_(joint_pad, 0)

    # Effective (unpadded) sequence lengths per batch element.
    txt_len = txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False)
    img_len = img_pad.size(1) - img_pad.sum(dim=1, keepdim=False)

    transport_plan = ipot(
        cost.detach(),
        txt_len.to(dtype=cost.dtype),
        txt_pad,
        img_len.to(dtype=cost.dtype),
        img_pad,
        joint_pad,
        beta,
        iteration,
        k,
    )
    # distance = <C, T> computed as trace(C @ T); the plan is detached so
    # gradients flow only through the cost matrix.
    return trace(cost.matmul(transport_plan.detach()))

0 commit comments

Comments
 (0)