Skip to content

Commit b672a74

Browse files
[feat] Add UNITER text processor (#1132)
Summary: Pull Request resolved: #1132 Add uniter_text_tokenizer which adds 'input_ids_masked' to the sample list separate from 'input_ids'. Test Plan: **Unit tests** Tests for the construction of the processor and processor output. We assert that the tokens are correct and contain the enhanced fields required for UNITER. Reviewed By: ebsmothers Differential Revision: D31865997 Pulled By: Ryan-Qiyu-Jiang fbshipit-source-id: cf8f8b312aebe39bdfe3d5f5516831e99e27888b
1 parent 9564983 commit b672a74

File tree

2 files changed

+217
-19
lines changed

2 files changed

+217
-19
lines changed

mmf/datasets/processors/bert_processors.py

Lines changed: 113 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,35 @@ def __init__(self, config, *args, **kwargs):
355355
self._probability = config.get("mask_probability", 0)
356356

357357

358+
def get_pair_text_tokens(item, masked_token_processor):
    """Tokenize one sample (optionally a sentence pair) with the given processor.

    The primary text is taken from ``item["text"]``, then ``item["text_a"]``,
    falling back to joining ``item["tokens"]``. An optional second sentence
    comes from ``item["text_b"]``. The pair is truncated to the processor's
    maximum sequence length and converted to the processor's index dict.
    """
    # Pick the primary text: "text" wins, then "text_a", else join raw tokens.
    if "text" in item:
        primary = item["text"]
    elif "text_a" in item:
        primary = item["text_a"]
    else:
        primary = " ".join(item["tokens"])

    # A list of words is flattened into one whitespace-joined string.
    if isinstance(primary, list):
        primary = " ".join(primary)

    first_tokens = masked_token_processor.tokenize(primary)

    # 'text_b' can be defined in the dataset preparation
    second_tokens = None
    secondary = item.get("text_b")
    if secondary:
        second_tokens = masked_token_processor.tokenize(secondary)

    masked_token_processor._truncate_seq_pair(
        first_tokens, second_tokens, masked_token_processor._max_seq_length
    )
    return masked_token_processor._convert_to_indices(
        first_tokens, second_tokens, probability=masked_token_processor._probability
    )
385+
386+
358387
@registry.register_processor("vilt_text_tokenizer")
359388
class VILTTextTokenizer(MaskedTokenProcessor):
360389
def __init__(self, config, *args, **kwargs):
@@ -372,28 +401,93 @@ def __init__(self, config, *args, **kwargs):
372401
self._probability = config.get("mask_probability", 0)
373402

374403
def __call__(self, item):
    """Tokenize ``item`` and mirror the produced token list under ``text``."""
    sample = get_pair_text_tokens(item, self)
    sample["text"] = sample["tokens"]
    return sample
381407

382-
if isinstance(text_a, list):
383-
text_a = " ".join(text_a)
384408

385-
tokens_a = self.tokenize(text_a)
409+
@registry.register_processor("uniter_text_tokenizer")
class UNITERTextTokenizer(MaskedTokenProcessor):
    """Text tokenizer for UNITER that emits masked and unmasked token ids.

    Besides the usual BERT fields, each sample carries ``input_ids_masked``
    (for MLM heads) alongside the unmasked ``input_ids``.
    """

    def __init__(self, config, *args, **kwargs):
        from transformers import BertTokenizer

        # Allow a bare pretrained-model name in place of a config mapping.
        if isinstance(config, str):
            config = {"from_pretrained": config}

        pretrained_name = config.get("from_pretrained", "bert-base-uncased")
        # Lowercase iff the checkpoint name indicates an uncased vocabulary.
        tokenizer_kwargs = dict(kwargs, do_lower_case="uncased" in pretrained_name)
        self._tokenizer = BertTokenizer.from_pretrained(
            pretrained_name, **tokenizer_kwargs
        )
        self._max_seq_length = config.get("max_seq_length", 25)
        self._probability = config.get("mask_probability", 0)

    def __call__(self, item: Dict[str, Any]):
        """Tokenize ``item``; ``text``/``tokens`` reflect the masked tokens."""
        output = get_pair_text_tokens(item, self)
        output["text"] = output["tokens_masked"]
        output["tokens"] = output["tokens_masked"]
        if "is_correct" in item:
            output["is_correct"] = torch.tensor(item["is_correct"], dtype=torch.long)
        return output

    def _token_transform(
        self, tokens: List[str], tokens_b: Optional[List[str]] = None
    ) -> Tuple[torch.Tensor, int, int, List[str]]:
        """Wrap tokens with [CLS]/[SEP], convert to ids and zero-pad.

        Returns the padded id tensor, the unpadded length, the pad count and
        the wrapped token list.
        """
        wrapped = [self._CLS_TOKEN] + tokens + [self._SEP_TOKEN]
        if tokens_b:
            wrapped = wrapped + tokens_b + [self._SEP_TOKEN]

        ids = self._convert_tokens_to_ids(wrapped)
        length = len(ids)
        pad = self._max_seq_length - length
        # Zero-pad up to the sequence length.
        ids = ids + [self._PAD_TOKEN_ID] * pad
        return torch.tensor(ids, dtype=torch.long), length, pad, wrapped

    def _convert_to_indices(
        self,
        tokens_a: List[str],
        tokens_b: Optional[List[str]] = None,
        probability: float = 0.15,
    ) -> Dict[str, torch.Tensor]:
        """
        BERT encodes
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        """
        # Capture the unmasked ids before any MLM token replacement happens.
        input_ids_original, _, _, _ = self._token_transform(tokens_a, tokens_b)

        tokens_a, label_a = self._random_word(tokens_a, probability=probability)
        segment_ids = [0] * (len(tokens_a) + 2)

        if tokens_b:
            tokens_b, label_b = self._random_word(tokens_b, probability=probability)
            assert len(tokens_b) > 0
            lm_label_ids = [-1] + label_a + [-1] + label_b + [-1]
            segment_ids = segment_ids + [1] * (len(tokens_b) + 1)
        else:
            lm_label_ids = [-1] + label_a + [-1]

        input_ids_masked, token_len, token_pad, tokens_masked = self._token_transform(
            tokens_a, tokens_b
        )

        # input_mask is non-padding (1) vs padding (0) mask (not MLM token masking)
        input_mask = [1] * token_len + [0] * token_pad
        segment_ids = segment_ids + [0] * token_pad
        lm_label_ids = lm_label_ids + [-1] * token_pad

        return {
            "input_ids_masked": input_ids_masked,  # specifically for MLM heads
            "input_ids": input_ids_original,  # unmasked tokens for CLIP heads
            "input_mask": torch.tensor(input_mask, dtype=torch.long),
            "segment_ids": torch.tensor(segment_ids, dtype=torch.long),
            "lm_label_ids": torch.tensor(lm_label_ids, dtype=torch.long),
            "tokens_masked": tokens_masked,
        }

tests/datasets/test_bert_processors.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,107 @@ def test_vilt_tokenizer(self):
185185

186186
# Test [MASK] token is present
187187
self.assertTrue(103 in results["input_ids"])
188+
189+
def test_uniter_tokenizer(self):
    """Exercise UNITERTextTokenizer on normal, empty, long, paired and fully masked captions."""
    from mmf.datasets.processors.bert_processors import UNITERTextTokenizer

    test_utils.setup_proxy()
    config = OmegaConf.create(
        {
            "tokenizer_config": {
                "type": "bert-base-uncased",
                "params": {"do_lower_case": True},
            },
            "mask_probability": 0.5,
            "max_seq_length": 128,
        }
    )

    processor = UNITERTextTokenizer(config)

    def _assert_masked_fields(out):
        # Masked ids must always be present and shaped like the unmasked ids.
        self.assertTrue("input_ids_masked" in out)
        self.assertEqual(out["input_ids"].shape, out["input_ids_masked"].shape)

    # Test normal caption
    out = processor({"text": "This will be a test of tokens?"})
    want_ids = torch.zeros(128, dtype=torch.long)
    want_ids[:11] = torch.tensor(
        [101, 2023, 2097, 2022, 1037, 3231, 1997, 19204, 2015, 1029, 102],
        dtype=torch.long,
    )
    want_segments = torch.zeros(128, dtype=torch.long)
    want_mask = torch.zeros(128, dtype=torch.long)
    want_mask[:11] = 1
    self.assertTrue(torch.equal(out["input_ids"], want_ids))
    self.assertTrue(torch.equal(out["segment_ids"], want_segments))
    self.assertTrue(torch.equal(out["input_mask"], want_mask))
    _assert_masked_fields(out)

    # Test empty caption
    out = processor({"text": ""})
    want_ids = torch.zeros(128, dtype=torch.long)
    want_ids[:2] = torch.tensor([101, 102], dtype=torch.long)
    want_segments = torch.zeros(128, dtype=torch.long)
    want_mask = torch.zeros(128, dtype=torch.long)
    want_mask[:2] = 1
    self.assertTrue(torch.equal(out["input_ids"], want_ids))
    self.assertTrue(torch.equal(out["segment_ids"], want_segments))
    self.assertTrue(torch.equal(out["input_mask"], want_mask))
    _assert_masked_fields(out)

    # Test long caption
    out = processor({"text": "I am working for facebook " * 100})  # make a long sentence
    id_list = [1045, 2572, 2551, 2005, 9130] * 100
    id_list.insert(0, 101)  # [CLS]
    id_list = id_list[:128]
    id_list[-1] = 102  # [SEP]
    want_ids = torch.tensor(id_list, dtype=torch.long)
    want_segments = torch.zeros(128, dtype=torch.long)
    want_mask = torch.ones(128, dtype=torch.long)
    self.assertTrue(torch.equal(out["input_ids"], want_ids))
    self.assertTrue(torch.equal(out["segment_ids"], want_segments))
    self.assertTrue(torch.equal(out["input_mask"], want_mask))
    _assert_masked_fields(out)

    # Test two captions
    out = processor(
        {
            "text_a": "This will be a test of tokens?",
            "text_b": "I am working for facebook",
        }
    )
    want_ids = torch.zeros(128, dtype=torch.long)
    want_ids[:17] = torch.tensor(
        [101, 2023, 2097, 2022, 1037, 3231, 1997, 19204, 2015, 1029, 102]
        + [1045, 2572, 2551, 2005, 9130, 102],
        dtype=torch.long,
    )
    want_segments = torch.zeros(128, dtype=torch.long)
    want_segments[11:17] = 1
    want_mask = torch.zeros(128, dtype=torch.long)
    want_mask[:17] = 1
    self.assertTrue(torch.equal(out["input_ids"], want_ids))
    self.assertTrue(torch.equal(out["segment_ids"], want_segments))
    self.assertTrue(torch.equal(out["input_mask"], want_mask))
    _assert_masked_fields(out)

    # Test masked caption
    processor._probability = 1.0
    out = processor({"text": "This will be a test of tokens?"})
    want_ids = torch.zeros(128, dtype=torch.long)
    want_ids[:11] = torch.tensor(
        [101, 2023, 2097, 2022, 1037, 3231, 1997, 19204, 2015, 1029, 102],
        dtype=torch.long,
    )
    want_segments = torch.zeros(128, dtype=torch.long)
    self.assertTrue(torch.equal(out["input_ids"], want_ids))
    self.assertTrue(torch.equal(out["segment_ids"], want_segments))
    _assert_masked_fields(out)

    # Test [MASK] token is present
    self.assertTrue(103 in out["input_ids_masked"])

0 commit comments

Comments
 (0)