I tried using the train_mlm script with GTE models from Hugging Face, specifically Alibaba-NLP/gte-base-en-v1.5 and thenlper/gte-base. With the Alibaba-NLP/gte-base-en-v1.5 model I hit the following error:
Traceback (most recent call last):
  File "C:\...\mlm_train.py", line 169, in <module>
    trainer.train()
  [..]
  File "C:\...\.cache\huggingface\modules\transformers_modules\Alibaba-NLP\new-impl\40ced75c3017eb27626c9d4ea981bde21a2662f4\modeling.py", line 1060, in forward
    mask = attention_mask.bool()
           ^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'bool'
With the thenlper/gte-base model, it runs but gives the following warning:
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
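Both symptoms seem to share the same root cause: as far as I can tell from the transformers source, DataCollatorForWholeWordMask returns a batch containing only input_ids and labels, so no attention_mask ever reaches the model. A minimal check of that (bert-base-uncased is just a stand-in tokenizer for illustration):

from transformers import AutoTokenizer, DataCollatorForWholeWordMask

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

examples = [tokenizer("a short sentence"), tokenizer("a somewhat longer example sentence")]
batch = collator(examples)

# Prints dict_keys(['input_ids', 'labels']) -- no attention_mask, which is
# why gte-base-en-v1.5's forward() sees attention_mask=None and thenlper/gte-base
# warns about unmasked padding tokens.
print(batch.keys())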
I was able to fix both problems by creating a custom data collator that adds the attention masks back to the batch, and using it in place of DataCollatorForWholeWordMask:
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import DataCollatorForWholeWordMask
[..]

class DataCollatorForWholeWordMaskWithAttentionMasks(DataCollatorForWholeWordMask):
    """Whole-word-mask collator that also returns padded attention masks."""

    def __call__(self, examples):
        # Build the usual masked batch; the parent class only returns
        # input_ids and labels.
        batch = super().__call__(examples)
        # Re-attach the attention masks, padded to the longest sequence
        # in the batch (padding_value=0 marks padding positions).
        attention_masks = [torch.tensor(example['attention_mask']) for example in examples]
        padded_attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)
        batch['attention_mask'] = padded_attention_masks
        return batch
[..]
if do_whole_word_mask:
    data_collator = DataCollatorForWholeWordMaskWithAttentionMasks(
        tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob
    )
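As a quick sanity check that the subclass actually restores the mask (train_dataset here is a placeholder for whatever tokenized dataset train_mlm builds; it needs to carry an attention_mask column for this to work):

sample = [train_dataset[i] for i in range(2)]
batch = data_collator(sample)
# Expected: dict_keys(['input_ids', 'labels', 'attention_mask'])
print(batch.keys())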
I'm unsure if this is the correct fix, though.