
train_mlm.py does not work with GTE v1.5 models. Requires attention masks. #2971

Open
JZacaroli opened this issue Oct 4, 2024 · 0 comments


JZacaroli commented Oct 4, 2024

I tried using train_mlm.py with the GTE models from Hugging Face, specifically Alibaba-NLP/gte-base-en-v1.5 and thenlper/gte-base. With the Alibaba-NLP/gte-base-en-v1.5 model I hit the following error:

Traceback (most recent call last):
  File "C:\...\mlm_train.py", line 169, in <module>
    trainer.train()
  [..]
  File "C:\...\.cache\huggingface\modules\transformers_modules\Alibaba-NLP\new-impl\40ced75c3017eb27626c9d4ea981bde21a2662f4\modeling.py", line 1060, in forward
    mask = attention_mask.bool()
           ^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'bool'

With the thenlper/gte-base model, training runs but gives the following warning:

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
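
As far as I can tell, the root cause in both cases is that the stock DataCollatorForWholeWordMask only returns input_ids and labels, so the model ends up being called with attention_mask=None. A minimal sketch to illustrate (the model name here is just an example):

from transformers import AutoTokenizer, DataCollatorForWholeWordMask

tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base")
collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# Tokenize two sentences of different lengths; each example does contain an attention_mask.
examples = [tokenizer(text) for text in ["a short sentence", "a much longer sentence with more tokens in it"]]
batch = collator(examples)

# The collated batch only has input_ids and labels; the attention_mask has been dropped.
print(batch.keys())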

I was able to fix these problems by creating a custom data collator that adds the attention masks back into the batch, and using it in place of DataCollatorForWholeWordMask:

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import DataCollatorForWholeWordMask

[..]

class DataCollatorForWholeWordMaskWithAttentionMasks(DataCollatorForWholeWordMask):
    """Whole-word-mask collator that also returns padded attention masks."""

    def __call__(self, examples):
        # The parent collator only returns input_ids and labels.
        batch = super().__call__(examples)
        # Collect each example's attention mask and pad to the batch's max sequence length.
        attention_masks = [torch.tensor(example['attention_mask']) for example in examples]
        padded_attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)
        batch['attention_mask'] = padded_attention_masks
        return batch

[..]

if do_whole_word_mask:
    data_collator = DataCollatorForWholeWordMaskWithAttentionMasks(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob)

I'm unsure if this is the correct fix, though.
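
For completeness, this is roughly how the collator plugs into the rest of the script. The Trainer setup below is only a sketch; the dataset variable, output directory and hyperparameters are placeholders:

from transformers import AutoModelForMaskedLM, AutoTokenizer, Trainer, TrainingArguments

model_name = "Alibaba-NLP/gte-base-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# trust_remote_code is needed because this model ships its own modeling.py
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)

data_collator = DataCollatorForWholeWordMaskWithAttentionMasks(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

training_args = TrainingArguments(
    output_dir="output/mlm-gte",    # placeholder
    per_device_train_batch_size=8,  # placeholder
    num_train_epochs=1,             # placeholder
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # assumed: a tokenized dataset that includes attention_mask
    data_collator=data_collator,
)
trainer.train()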
