@@ -355,6 +355,37 @@ def __init__(self, config, *args, **kwargs):
         self._probability = config.get("mask_probability", 0)
 
 
+def get_pair_text_tokens(item, masked_token_processor):
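+    # Shared between the VILT and UNITER tokenizers below: pull the raw text
+    # out of the item, tokenize it, truncate the pair, and convert to indices.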
359+ if "text" in item :
360+ text_a = item ["text" ]
361+ elif "text_a" in item :
362+ text_a = item ["text_a" ]
363+ else :
364+ text_a = " " .join (item ["tokens" ])
365+
366+ if isinstance (text_a , list ):
367+ text_a = " " .join (text_a )
368+
369+ tokens_a = masked_token_processor .tokenize (text_a )
370+
371+ # 'text_b' can be defined in the dataset preparation
372+ tokens_b = None
373+ if "text_b" in item :
374+ text_b = item ["text_b" ]
375+ if text_b :
376+ tokens_b = masked_token_processor .tokenize (text_b )
377+
378+ masked_token_processor ._truncate_seq_pair (
379+ tokens_a , tokens_b , masked_token_processor ._max_seq_length
380+ )
381+ output = masked_token_processor ._convert_to_indices (
382+ tokens_a , tokens_b , probability = masked_token_processor ._probability
383+ )
384+ return output
385+
386+
358387@registry .register_processor ("vilt_text_tokenizer" )
359388class VILTTextTokenizer (MaskedTokenProcessor ):
360389 def __init__ (self , config , * args , ** kwargs ):
@@ -372,28 +403,98 @@ def __init__(self, config, *args, **kwargs):
         self._probability = config.get("mask_probability", 0)
 
     def __call__(self, item):
-        if "text" in item:
-            text_a = item["text"]
-        elif "text_a" in item:
-            text_a = item["text_a"]
-        else:
-            text_a = " ".join(item["tokens"])
+        output = get_pair_text_tokens(item, self)
+        output["text"] = output["tokens"]
+        return output
 
-        if isinstance(text_a, list):
-            text_a = " ".join(text_a)
 
-        tokens_a = self.tokenize(text_a)
+@registry.register_processor("uniter_text_tokenizer")
+class UNITERTextTokenizer(MaskedTokenProcessor):
+    def __init__(self, config, *args, **kwargs):
+        from transformers import BertTokenizer
 
-        # 'text_b' can be defined in the dataset preparation
-        tokens_b = None
-        if "text_b" in item:
-            text_b = item["text_b"]
-            if text_b:
-                tokens_b = self.tokenize(text_b)
+        if isinstance(config, str):
+            config = {"from_pretrained": config}
 
-        self._truncate_seq_pair(tokens_a, tokens_b, self._max_seq_length)
-        output = self._convert_to_indices(
-            tokens_a, tokens_b, probability=self._probability
+        from_pretrained_name = config.get("from_pretrained", "bert-base-uncased")
+        kwargs_dict = dict(kwargs, do_lower_case="uncased" in from_pretrained_name)
+        self._tokenizer = BertTokenizer.from_pretrained(
+            from_pretrained_name, **kwargs_dict
         )
-        output["text"] = output["tokens"]
+        self._max_seq_length = config.get("max_seq_length", 25)
+        self._probability = config.get("mask_probability", 0)
+
+    def __call__(self, item: Dict[str, Any]):
+        output = get_pair_text_tokens(item, self)
+        output["text"] = output["tokens_masked"]
+        output["tokens"] = output["tokens_masked"]
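+        # Optional binary target (e.g. image-text matching) from the dataset.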
429+ if "is_correct" in item :
430+ output ["is_correct" ] = torch .tensor (
431+ item .get ("is_correct" , True ), dtype = torch .long
432+ )
399433 return output
434+
435+ def _token_transform (
436+ self , tokens : List [str ], tokens_b : Optional [List [str ]] = None
437+ ) -> Tuple [torch .Tensor , int , int , List [str ]]:
438+ tokens = [self ._CLS_TOKEN ] + tokens + [self ._SEP_TOKEN ]
439+ if tokens_b :
440+ tokens += tokens_b + [self ._SEP_TOKEN ]
441+
442+ input_ids = self ._convert_tokens_to_ids (tokens )
443+ token_len = len (input_ids )
444+ token_pad = self ._max_seq_length - token_len
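+        # Assumes _truncate_seq_pair already ran, so token_pad is non-negative.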
+        # Zero-pad up to the sequence length.
+        input_ids += [self._PAD_TOKEN_ID] * token_pad
+        input_ids_tensor = torch.tensor(input_ids, dtype=torch.long)
+        return input_ids_tensor, token_len, token_pad, tokens
+
+    def _convert_to_indices(
+        self,
+        tokens_a: List[str],
+        tokens_b: Optional[List[str]] = None,
+        probability: float = 0.15,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        BERT encodes
+        - single sequence: ``[CLS] X [SEP]``
+        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
+        """
+        input_ids_original, _, _, _ = self._token_transform(tokens_a, tokens_b)
+
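+        # _random_word masks tokens with the given probability; the returned
+        # labels hold the original ids at masked positions and -1 elsewhere.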
+        tokens_a, label_a = self._random_word(tokens_a, probability=probability)
+        segment_ids = [0] * (len(tokens_a) + 2)
+
+        if tokens_b:
+            tokens_b, label_b = self._random_word(tokens_b, probability=probability)
+            lm_label_ids = [-1] + label_a + [-1] + label_b + [-1]
+            assert len(tokens_b) > 0
+            segment_ids += [1] * (len(tokens_b) + 1)
+        else:
+            lm_label_ids = [-1] + label_a + [-1]
+
+        input_ids_masked, token_len, token_pad, tokens_masked = self._token_transform(
+            tokens_a, tokens_b
+        )
+
+        input_mask = [1] * token_len + [0] * token_pad
+        segment_ids += [0] * token_pad
+        lm_label_ids += [-1] * token_pad
+
+        input_mask = torch.tensor(input_mask, dtype=torch.long)
+        segment_ids = torch.tensor(segment_ids, dtype=torch.long)
+        lm_label_ids = torch.tensor(lm_label_ids, dtype=torch.long)
+        return {
+            "input_ids_masked": input_ids_masked,  # specifically for MLM heads
+            "input_ids": input_ids_original,  # unmasked tokens for CLIP heads
+            # input_mask is non-padding (1) vs padding (0) mask (not MLM token masking)
+            "input_mask": input_mask,
+            "segment_ids": segment_ids,
+            "lm_label_ids": lm_label_ids,
+            "tokens_masked": tokens_masked,
+        }
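
For context, here is a minimal sketch of how the new uniter_text_tokenizer processor could be exercised end to end. The config keys mirror what __init__ reads; the import path, the OmegaConf wrapper, and the sample item are illustrative assumptions, not part of this diff:

    from omegaconf import OmegaConf

    # Assumed import path, following MMF's usual processor layout.
    from mmf.datasets.processors.bert_processors import UNITERTextTokenizer

    config = OmegaConf.create(
        {
            "from_pretrained": "bert-base-uncased",  # HuggingFace tokenizer name
            "max_seq_length": 25,  # same default as __init__
            "mask_probability": 0.15,  # enable MLM masking
        }
    )
    processor = UNITERTextTokenizer(config)

    item = {"text": "a man riding a horse", "is_correct": True}
    output = processor(item)

    # __call__ returns the _convert_to_indices dict plus "text", "tokens",
    # and "is_correct"; all id tensors are padded to max_seq_length.
    print(output["input_ids"].shape)  # torch.Size([25])
    print(output["lm_label_ids"].shape)  # torch.Size([25])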