From d526d98b1b44899c2ba7e2c20addb22be2faecbd Mon Sep 17 00:00:00 2001 From: Antoine Chaffin <38869395+NohTow@users.noreply.github.com> Date: Fri, 13 Sep 2024 15:03:14 +0200 Subject: [PATCH] Allow to set the prefixes for stanford-nlp models (#55) * Allow to set the prefixes for stanford-nlp models * Version bump * Add documentation for the jina-colbert-v2 model --- docs/api/losses/Contrastive.md | 4 ++-- docs/api/losses/Distillation.md | 4 ++-- docs/api/models/ColBERT.md | 8 ++++---- docs/api/models/Dense.md | 15 ++++++++++++++- docs/models/models.md | 27 ++++++++++++++++++++++++++- pylate/__version__.py | 2 +- pylate/models/colbert.py | 15 +++++++++++---- 7 files changed, 60 insertions(+), 15 deletions(-) diff --git a/docs/api/losses/Contrastive.md b/docs/api/losses/Contrastive.md index 29a2bf0..2c4ecae 100644 --- a/docs/api/losses/Contrastive.md +++ b/docs/api/losses/Contrastive.md @@ -10,7 +10,7 @@ Contrastive loss. Expects as input two texts and a label of either 0 or 1. If th ColBERT model. -- **score_metric** – defaults to `` +- **score_metric** – defaults to `` ColBERT scoring function. Defaults to colbert_scores. @@ -228,7 +228,7 @@ Contrastive loss. Expects as input two texts and a label of either 0 or 1. If th Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. - If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. 
Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. + If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. 
* **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. **Parameters** diff --git a/docs/api/losses/Distillation.md b/docs/api/losses/Distillation.md index 623b40e..62ec61e 100644 --- a/docs/api/losses/Distillation.md +++ b/docs/api/losses/Distillation.md @@ -10,7 +10,7 @@ Distillation loss for ColBERT model. The loss is computed with respect to the fo SentenceTransformer model. -- **score_metric** (*Callable*) – defaults to `` +- **score_metric** (*Callable*) – defaults to `` Function that returns a score between two sequences of embeddings. @@ -232,7 +232,7 @@ Distillation loss for ColBERT model. The loss is computed with respect to the fo Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. - If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. 
Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. + If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. * **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. 
**Parameters** diff --git a/docs/api/models/ColBERT.md b/docs/api/models/ColBERT.md index b2b7a19..ce32c86 100644 --- a/docs/api/models/ColBERT.md +++ b/docs/api/models/ColBERT.md @@ -64,11 +64,11 @@ Loads or creates a ColBERT model that can be used to map sentences / text to mul - **bias** (*bool*) – defaults to `False` -- **query_prefix** (*str | None*) – defaults to `[Q] ` +- **query_prefix** (*str | None*) – defaults to `None` Prefix to add to the queries. -- **document_prefix** (*str | None*) – defaults to `[D] ` +- **document_prefix** (*str | None*) – defaults to `None` Prefix to add to the documents. @@ -494,7 +494,7 @@ Loads or creates a ColBERT model that can be used to map sentences / text to mul Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. - If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. 
Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. + If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. * **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. 
**Parameters** @@ -602,7 +602,7 @@ Loads or creates a ColBERT model that can be used to map sentences / text to mul **Returns** - *list*: A list of pooled embeddings for each document. + *list[torch.Tensor]*: A list of pooled embeddings for each document. ???- note "pop" diff --git a/docs/api/models/Dense.md b/docs/api/models/Dense.md index b9007dd..a78e887 100644 --- a/docs/api/models/Dense.md +++ b/docs/api/models/Dense.md @@ -176,6 +176,19 @@ Performs linear projection on the token embeddings to a lower dimension. - **dense** (*sentence_transformers.models.Dense.Dense*) +???- note "from_stanford_weights" + + Load the weight of the Dense layer using weights from a stanford-nlp checkpoint. + + **Parameters** + + - **model_name_or_path** (*str | os.PathLike*) + - **cache_folder** (*str | os.PathLike | None*) – defaults to `None` + - **revision** (*str | None*) – defaults to `None` + - **local_files_only** (*bool | None*) – defaults to `None` + - **token** (*str | bool | None*) – defaults to `None` + - **use_auth_token** (*str | bool | None*) – defaults to `None` + ???- note "get_buffer" Return the buffer given by ``target`` if it exists, otherwise throw an error. @@ -244,7 +257,7 @@ Performs linear projection on the token embeddings to a lower dimension. Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. - If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. 
Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. + If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. 
* **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``. **Parameters** diff --git a/docs/models/models.md b/docs/models/models.md index 8c9f3a6..0ef81c1 100644 --- a/docs/models/models.md +++ b/docs/models/models.md @@ -1,6 +1,30 @@ # Available models -Here is a list of the pre-trained ColBERT models available in PyLate along with their results on BEIR: +!!! tip + Following an update, all the models trained using the stanford-nlp ColBERT library or RAGatouille should be compatible with PyLate natively. + You can simply load the model in PyLate: + + ```python + from pylate import models + + model = models.ColBERT( + model_name_or_path="colbert-ir/colbertv2.0", + ) + ``` + + Note that some models use non-default parameters (e.g., prefixes), so you might have to specify them. For example, for the [jinaai/jina-colbert-v2](https://huggingface.co/jinaai/jina-colbert-v2) model, you need to specify the prefixes and allow the query tokens to attend to expansion tokens: + ```python + model = models.ColBERT( + model_name_or_path="jinaai/jina-colbert-v2", + query_prefix="[QueryMarker]", + document_prefix="[DocumentMarker]", + attend_to_expansion_tokens=True, + trust_remote_code=True, + ) + ``` + + +Here is a list of some of the pre-trained ColBERT models available in PyLate along with their results on BEIR: === "Table" @@ -8,6 +32,7 @@ Here is a list of the pre-trained ColBERT models available in PyLate along with |---------------------------------------|----------|----------|---------|---------|----------|-----------|----------|------------|---------|--------------|-------|----------------|------|---------| | [lightonai/colbertv2.0](https://huggingface.co/lightonai/colbertv2.0) | 50.02 | 33.8 | 69.3 | 15.4 | 35.6 | 73.3 | 66.7 | 
26.3 | 46.3 | 17.6 | 78.5 | 85.2 | 56.2 | 44.6 | | [answerdotai/answerai-colbert-small-v1](https://huggingface.co/answerdotai/answerai-colbert-small-v1) | 53.79 | 37.3 | 74.77 | 18.42 | 41.15 | 84.59 | 76.11 | 25.69 | 50.09 | 33.07 | 90.96 | 87.72 | 59.1 | 45.58 | +| [jinaai/jina-colbert-v2](https://huggingface.co/jinaai/jina-colbert-v2) | 53.1 | 34.6 | 67.8 | 18.6 | 40.8 | 83.4 | 76.6 | 27.4 | 36.6 | 23.9 | 80.05 | 88.7 | 64.0 | 47.1 | ???+ note diff --git a/pylate/__version__.py b/pylate/__version__.py index 5b3343a..e851482 100644 --- a/pylate/__version__.py +++ b/pylate/__version__.py @@ -1,3 +1,3 @@ -VERSION = (1, 1, 1) +VERSION = (1, 1, 2) __version__ = ".".join(map(str, VERSION)) diff --git a/pylate/models/colbert.py b/pylate/models/colbert.py index df5ac82..987001a 100644 --- a/pylate/models/colbert.py +++ b/pylate/models/colbert.py @@ -206,8 +206,8 @@ def __init__( truncate_dim: int | None = None, embedding_size: int | None = None, bias: bool = False, - query_prefix: str | None = "[Q] ", - document_prefix: str | None = "[D] ", + query_prefix: str | None = None, + document_prefix: str | None = None, add_special_tokens: bool = True, truncation: bool = True, query_length: int | None = None, @@ -262,8 +262,10 @@ def __init__( ) ) # Setting the prefixes from stanford-nlp models - self.query_prefix = "[unused0]" - self.document_prefix = "[unused1]" + if self.query_prefix is None: + self.query_prefix = "[unused0]" + if self.document_prefix is None: + self.document_prefix = "[unused1]" logger.warning("Loaded the ColBERT model from Stanford NLP.") else: # Add a linear projection layer to the model in order to project the embeddings to the desired size @@ -308,6 +310,11 @@ def __init__( self.to(device) self.is_hpu_graph_enabled = False + if self.query_prefix is None: + self.query_prefix = "[Q] " + if self.document_prefix is None: + self.document_prefix = "[D] " + # Try adding the prefixes to the tokenizer. 
We call resize_token_embeddings twice to ensure the tokens are added only if resize_token_embeddings works. There should be a better way to do this. try: self._first_module().auto_model.resize_token_embeddings(len(self.tokenizer))