Merge pull request #30 from lightonai/fix_prefix
Adding a proper function to add the prefixes
NohTow authored Aug 9, 2024
2 parents d76f1d1 + 17c381e commit 6c3b5c5
Showing 1 changed file with 33 additions and 11 deletions.
giga_cherche/models/colbert.py: 44 changes (33 additions & 11 deletions)
@@ -368,8 +368,8 @@ def __init__(
         self.model_card_data.register_model(self)
 
         # this will add the query and document prefix to the tokenizer vocab if they are not already there and resize the embeddings accordingly
-        # self.tokenizer.add_tokens([self.query_prefix, self.document_prefix])
-        # self._first_module().auto_model.resize_token_embeddings(len(self.tokenizer))
+        self.tokenizer.add_tokens([self.query_prefix, self.document_prefix])
+        self._first_module().auto_model.resize_token_embeddings(len(self.tokenizer))
 
         self.document_prefix_id = self.tokenizer.convert_tokens_to_ids(
             self.document_prefix
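This hunk re-enables the two previously commented-out lines: the query and document prefixes are registered as real tokens in the vocabulary, and the embedding matrix is resized to match. A minimal standalone sketch of that pattern, not the library code itself, assuming bert-base-uncased as a stand-in model and the made-up prefix strings "[Q]" and "[D]" (giga_cherche's actual prefix values may differ):

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

# Hypothetical prefix strings for illustration only.
num_added = tokenizer.add_tokens(["[Q]", "[D]"])

# Grow the embedding matrix so the new token ids do not index out of bounds.
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))

query_prefix_id = tokenizer.convert_tokens_to_ids("[Q]")  # e.g. 30522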
@@ -1126,6 +1126,12 @@ def get_max_seq_length(self) -> int | None:
 
         return None
 
+    def insert_prefix_token(self, tensor, prefix_id):
+        prefix_tensor = torch.full(
+            (tensor.size(0), 1), prefix_id, dtype=tensor.dtype, device=tensor.device
+        )
+        return torch.cat([tensor[:, :1], prefix_tensor, tensor[:, 1:]], dim=1)
+
     def tokenize(
         self,
         texts: Union[list[str], list[dict], list[tuple[str, str]]],
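The hunk above adds the new insert_prefix_token helper: it builds a column of prefix ids and splices it in right after the first column (the [CLS] position) of a [batch, seq_len] tensor, instead of overwriting an existing token. A toy run, assuming BERT-style ids (101 = [CLS], 102 = [SEP]) and a made-up prefix id:

import torch

def insert_prefix_token(tensor, prefix_id):
    # Same logic as the helper in the diff, shown standalone.
    prefix_tensor = torch.full(
        (tensor.size(0), 1), prefix_id, dtype=tensor.dtype, device=tensor.device
    )
    return torch.cat([tensor[:, :1], prefix_tensor, tensor[:, 1:]], dim=1)

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # [CLS] hello world [SEP]
print(insert_prefix_token(input_ids, 30522))
# e.g. tensor([[  101, 30522,  7592,  2088,   102]])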
@@ -1142,17 +1148,22 @@ def tokenize(
             dict[str, torch.Tensor]: A dictionary of tensors with the tokenized texts. Common keys are "input_ids",
                 "attention_mask", and "token_type_ids".
         """
-        # TODO: add the skiplist
-        # Add placeholder for the document/query prefix
-        texts = [". " + text for text in texts]
         if is_query:
             # TODO: This is a hack to asymetrically set the max_seq_length for the query/document, change it once the Transformer module tokenize function expose a max_length argument
             self._first_module().max_seq_length = self.query_length
             features = self._first_module().tokenize(texts, padding="max_length")
-            # Remplace the second token by the query prefix
-            # TODO: Do this in a prettier way. Okay we cannot directly add the text in the string, but this is not robust (multiple ids, ...)
-            # e.g : # features["input_ids"] = torch.cat((features["input_ids"][:, :1], self.document_query_id, ids[:, 1:]), dim=1) ; features["attention_mask"] = torch.cat((features["attention_mask"][:, :1], torch.ones((features["attention_mask"].shape[0], 1), dtype=torch.int8), features["attention_mask"][:, 1:]), dim=1)
-            features["input_ids"][:, 1] = self.query_prefix_id
+            # Create a new tensor with the query prefix ID inserted after the first token
+            features["input_ids"] = self.insert_prefix_token(
+                features["input_ids"], self.query_prefix_id
+            )
+            # Update the attention mask to account for the new token
+            features["attention_mask"] = self.insert_prefix_token(
+                features["attention_mask"], 1
+            )
+            if "token_type_ids" in features:
+                features["token_type_ids"] = self.insert_prefix_token(
+                    features["token_type_ids"], 0
+                )
             # In the original ColBERT, the original tokens do not attend to the expansion tokens (but the expansion tokens attend to original tokens)
             if self.attend_to_expansion_tokens:
                 # Fill the attention mask with ones (we attend to "padding" tokens used for expansion)
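Because the insertion lengthens input_ids by one column, the attention mask and token type ids must be lengthened the same way (with a 1 and a 0 respectively) so that position i in every tensor still refers to the same token. A sketch of that alignment under the same toy assumptions as above (BERT-style ids, made-up prefix id):

import torch

def insert_prefix_token(tensor, prefix_id):
    prefix = torch.full((tensor.size(0), 1), prefix_id, dtype=tensor.dtype)
    return torch.cat([tensor[:, :1], prefix, tensor[:, 1:]], dim=1)

query_prefix_id = 30522  # made-up id for an added "[Q]" token
features = {
    "input_ids": torch.tensor([[101, 7592, 102, 0]]),  # padded to length 4
    "attention_mask": torch.tensor([[1, 1, 1, 0]]),
    "token_type_ids": torch.tensor([[0, 0, 0, 0]]),
}

features["input_ids"] = insert_prefix_token(features["input_ids"], query_prefix_id)
features["attention_mask"] = insert_prefix_token(features["attention_mask"], 1)
features["token_type_ids"] = insert_prefix_token(features["token_type_ids"], 0)
# All three tensors are now length 5 and stay column-aligned.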
@@ -1164,8 +1175,19 @@
             if pad_document:
                 extra_parameters["padding"] = "max_length"
             features = self._first_module().tokenize(texts, **extra_parameters)
-            # Remplace the second token by the document prefix
-            features["input_ids"][:, 1] = self.document_prefix_id
+            # Create a new tensor with the document prefix ID inserted after the first token
+            features["input_ids"] = self.insert_prefix_token(
+                features["input_ids"], self.document_prefix_id
+            )
+            # Update the attention mask to account for the new token
+            features["attention_mask"] = self.insert_prefix_token(
+                features["attention_mask"], 1
+            )
+            if "token_type_ids" in features:
+                features["token_type_ids"] = self.insert_prefix_token(
+                    features["token_type_ids"], 0
+                )
+
         return features
 
     def get_sentence_features(self, *features):
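The net effect of the commit: the old code prepended a ". " placeholder to every text and then overwrote the resulting second token id in place, which silently breaks whenever the placeholder tokenizes to more than one id; the new code tokenizes the raw text and inserts the prefix id instead. A hedged before/after comparison, again assuming bert-base-uncased and a made-up "[D]" prefix token:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["[D]"])
document_prefix_id = tokenizer.convert_tokens_to_ids("[D]")

# Old approach: placeholder text plus in-place overwrite of the second token.
old = tokenizer(". hello world", return_tensors="pt")["input_ids"]
old[:, 1] = document_prefix_id  # assumes ". " became exactly one id

# New approach: tokenize the raw text, then insert the prefix id after [CLS].
new = tokenizer("hello world", return_tensors="pt")["input_ids"]
prefix = torch.full((new.size(0), 1), document_prefix_id, dtype=new.dtype)
new = torch.cat([new[:, :1], prefix, new[:, 1:]], dim=1)

print(old)  # e.g. tensor([[  101, 30522,  7592,  2088,   102]])
print(new)  # same ids here, but without the single-id placeholder assumption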
