
fix wrong 'cls' masking for bigbird qa model output#13143

Merged
patrickvonplaten merged 1 commit into huggingface:master from donggyukimc:fix-wrong-masking-bigbird-qa
Sep 1, 2021

Conversation

@donggyukimc
Contributor

@donggyukimc donggyukimc commented Aug 16, 2021

What does this PR do?

Currently, the BigBird QA model masks out all logits before the context tokens (assigning a large negative value, below -1e6), as follows.

tokens : ['[CLS]', '▁How', '▁old', '▁are', '▁you', '?', '[SEP]', '▁I', "'m", '▁twenty', '▁years', '▁old', '.']
input_ids : [65, 1475, 1569, 490, 446, 131, 66, 415, 1202, 8309, 913, 1569, 114]
attention_mask : [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
token_type_ids : [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]

start_logits:  
[-1.00000231e+06 -1.00000294e+06 -1.00000794e+06 -1.00000525e+06
 -1.00000344e+06 -1.00000288e+06 -9.99994312e+05 -2.53751278e+00
 -7.34928894e+00  4.26531649e+00 -6.21708155e+00 -8.17963409e+00
 -6.25242186e+00]
end_logits:  
[-1.00000169e+06 -1.00000869e+06 -1.00000731e+06 -1.00001088e+06
 -1.00000856e+06 -1.00000781e+06 -9.99996375e+05 -9.58227539e+00
 -9.81797123e+00 -2.89585280e+00  1.97710574e+00 -9.89597499e-01
 -5.21932888e+00]

As you can see, it also masks out the logits of the [CLS] token. This is because the following function builds the question mask based on the position of the first [SEP] token.

def prepare_question_mask(q_lengths: torch.Tensor, maxlen: int):

However, this is wrong, because the [CLS] token is used to predict "unanswerable question" in many QA models.

So I simply changed the code so that the masking of the [CLS] token is disabled right after the creation of token_type_ids.
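A minimal sketch of the idea (illustrative only, not the exact PR diff; the mask convention and this reconstruction of `prepare_question_mask` are assumptions):

```python
import torch

def prepare_question_mask(q_lengths: torch.Tensor, maxlen: int) -> torch.Tensor:
    # Illustrative reconstruction: 1 marks positions to mask (question tokens
    # up to the first [SEP]), 0 marks positions whose logits are kept.
    positions = torch.arange(maxlen).unsqueeze(0)       # (1, maxlen)
    mask = (positions < q_lengths.unsqueeze(1)).long()  # (batch, maxlen)
    # The fix: never mask position 0 ([CLS]), so its logits remain usable
    # for the "unanswerable" prediction.
    mask[:, 0] = 0
    return mask

# Question spans the first 7 positions (up to and including [SEP]),
# matching the tokenized example above.
question_mask = prepare_question_mask(torch.tensor([7]), maxlen=13)
```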

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten
Contributor

Hey @donggyukimc,

Thanks for your PR - this makes sense to me. Do you by any chance have a reference to the original code / paper that shows the CLS token should not be masked out?

Also cc-ing our expert on BigBird here @vasudevgupta7

@thevasudevgupta
Contributor

@donggyukimc, I am a little unsure about this. In the original code, they also mask out everything up to the first [SEP] (see this).

If we don't mask the CLS token, there is a possibility that start_token points to CLS while end_token points to some token in the sequence, so the final answer would include the question as well. I think the case of whether an answer is present (or not) should instead be handled by putting a classifier over the pooler layer (something like this). If we make the model point start_token & end_token to CLS during training, it usually leads to infinite/NaN loss, whereas the classifier approach works well.

Correct me if you feel I am wrong somewhere.
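The classifier-over-pooler alternative mentioned above could be sketched roughly as follows (a hypothetical `AnswerabilityHead`, not BigBird's actual code):

```python
import torch
import torch.nn as nn

class AnswerabilityHead(nn.Module):
    """Hypothetical binary head over the pooled [CLS] representation,
    predicting answerable vs. unanswerable instead of unmasking CLS logits."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, 2)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        # pooled_output: (batch, hidden_size) -> (batch, 2) class logits
        return self.classifier(pooled_output)

head = AnswerabilityHead(hidden_size=8)
answerability_logits = head(torch.randn(2, 8))
```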

@donggyukimc
Contributor Author

@vasudevgupta7, Thank you for your comment.

Here is the QA head code from other architectures (BERT, RoBERTa):

logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()

total_loss = None
if start_positions is not None and end_positions is not None:
    # If we are on multi-GPU, split add a dimension
    if len(start_positions.size()) > 1:
        start_positions = start_positions.squeeze(-1)
    if len(end_positions.size()) > 1:
        end_positions = end_positions.squeeze(-1)
    # sometimes the start/end positions are outside our model inputs, we ignore these terms
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
    start_loss = loss_fct(start_logits, start_positions)
    end_loss = loss_fct(end_logits, end_positions)
    total_loss = (start_loss + end_loss) / 2

if not return_dict:
    output = (start_logits, end_logits) + outputs[2:]
    return ((total_loss,) + output) if total_loss is not None else output

return QuestionAnsweringModelOutput(
    loss=total_loss,
    start_logits=start_logits,
    end_logits=end_logits,
    hidden_states=outputs.hidden_states,
    attentions=outputs.attentions,
)
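The clamp-and-ignore trick in that snippet can be checked in isolation (a standalone example, not BigBird code): positions clamped to `seq_len` are skipped by `CrossEntropyLoss(ignore_index=seq_len)`, so out-of-range labels never contribute to (or blow up) the loss.

```python
import torch
from torch.nn import CrossEntropyLoss

seq_len = 4
torch.manual_seed(0)
logits = torch.randn(2, seq_len)

# Second position (9) is outside the sequence; clamping maps it to seq_len,
# which is exactly the ignore_index, so it is dropped from the loss average.
positions = torch.tensor([2, 9]).clamp(0, seq_len)

loss_fct = CrossEntropyLoss(ignore_index=seq_len)
loss = loss_fct(logits, positions)

# The loss equals the cross-entropy of the first (valid) example alone.
expected = CrossEntropyLoss()(logits[:1], positions[:1])
```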

Even though neither of them applies any mask to the predictions for CLS (or the question tokens), they can be trained without loss problems. (In fact, CLS shouldn't be masked out, because they predict the unanswerable probability from CLS.)

As you can see in squad_metrics.py, the QA evaluation pipeline in the transformers library,

if version_2_with_negative:
    feature_null_score = result.start_logits[0] + result.end_logits[0]
    if feature_null_score < score_null:
        score_null = feature_null_score
        min_null_feature_index = feature_index
        null_start_logit = result.start_logits[0]
        null_end_logit = result.end_logits[0]
for start_index in start_indexes:
    for end_index in end_indexes:
        # We could hypothetically create invalid predictions, e.g., predict
        # that the start of the span is in the question. We throw out all
        # invalid predictions.
        if start_index >= len(feature.tokens):
            continue
        if end_index >= len(feature.tokens):
            continue
        if start_index not in feature.token_to_orig_map:
            continue
        if end_index not in feature.token_to_orig_map:
            continue

it directly computes the unanswerable score from the same MLP logit outputs used for answerable spans:

feature_null_score = result.start_logits[0] + result.end_logits[0]

One of your concerns (that start_token could point to CLS while end_token points to a token in the sequence, so the final answer would include the question) is prevented in this part:

if start_index not in feature.token_to_orig_map:
    continue
if end_index not in feature.token_to_orig_map:
    continue

because the positions of question tokens do not exist in feature.token_to_orig_map.
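To make the mechanism concrete (made-up logit values, not real model output): with [CLS] unmasked, the no-answer score is simply the sum of the start/end logits at position 0, exactly as squad_metrics.py computes it.

```python
# Made-up logits; index 0 is the [CLS] position.
start_logits = [4.2, -7.3, 1.1, -6.2]
end_logits = [3.9, -9.8, -2.9, 2.0]

feature_null_score = start_logits[0] + end_logits[0]  # no-answer score

# If the [CLS] logits were masked to about -1e6 (as before this PR), this
# score would always be about -2e6 and "unanswerable" could never win
# against any answerable span.
```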

Your suggestion of using a separate MLP to predict the unanswerable probability would also work, but then you would need evaluation code other than squad_metrics.py.

Actually, this is how I found the problem: I got wrong prediction results when I used the BigBird QA model with squad_metrics.py.

In my opinion, it is better to use the same prediction mechanism, to keep compatibility between the other QA model architectures and the QA evaluation pipeline in the transformers library.

I'd like to hear your opinion on this.

Thank you for your thoughtful comment again, @vasudevgupta7.

@donggyukimc
Contributor Author

Any thoughts on my opinion? @patrickvonplaten @vasudevgupta7

@thevasudevgupta
Contributor

Hey @donggyukimc, so sorry I missed your comment earlier. Given what you pointed out about BERT-like models, I think it's fine to unmask the CLS token to maintain consistency with other models. So, this PR looks alright to me.

@patrickvonplaten
Contributor

Awesome, merging it then!

@patrickvonplaten patrickvonplaten merged commit ba1b3db into huggingface:master Sep 1, 2021