Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to set the prefixes for stanford-nlp models #55

Merged
merged 3 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/api/losses/Contrastive.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Contrastive loss. Expects as input two texts and a label of either 0 or 1. If th

ColBERT model.

- **score_metric** – defaults to `<function colbert_scores at 0x14073dcf0>`
- **score_metric** – defaults to `<function colbert_scores at 0x7fc43e97f7e0>`

ColBERT scoring function. Defaults to colbert_scores.

Expand Down Expand Up @@ -228,7 +228,7 @@ Contrastive loss. Expects as input two texts and a label of either 0 or 1. If th

Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. * **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.

**Parameters**

Expand Down
4 changes: 2 additions & 2 deletions docs/api/losses/Distillation.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Distillation loss for ColBERT model. The loss is computed with respect to the fo

SentenceTransformer model.

- **score_metric** (*Callable*) – defaults to `<function colbert_kd_scores at 0x16ec65120>`
- **score_metric** (*Callable*) – defaults to `<function colbert_kd_scores at 0x7fc43ea4bc40>`

Function that returns a score between two sequences of embeddings.

Expand Down Expand Up @@ -232,7 +232,7 @@ Distillation loss for ColBERT model. The loss is computed with respect to the fo

Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. * **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.

**Parameters**

Expand Down
8 changes: 4 additions & 4 deletions docs/api/models/ColBERT.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ Loads or creates a ColBERT model that can be used to map sentences / text to mul

- **bias** (*bool*) – defaults to `False`

- **query_prefix** (*str | None*) – defaults to `[Q] `
- **query_prefix** (*str | None*) – defaults to `None`

Prefix to add to the queries.

- **document_prefix** (*str | None*) – defaults to `[D] `
- **document_prefix** (*str | None*) – defaults to `None`

Prefix to add to the documents.

Expand Down Expand Up @@ -494,7 +494,7 @@ Loads or creates a ColBERT model that can be used to map sentences / text to mul

Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. * **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.

**Parameters**

Expand Down Expand Up @@ -602,7 +602,7 @@ Loads or creates a ColBERT model that can be used to map sentences / text to mul

**Returns**

*list*: A list of pooled embeddings for each document.
*list[torch.Tensor]*: A list of pooled embeddings for each document.

???- note "pop"

Expand Down
15 changes: 14 additions & 1 deletion docs/api/models/Dense.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,19 @@ Performs linear projection on the token embeddings to a lower dimension.

- **dense** (*sentence_transformers.models.Dense.Dense*)

???- note "from_stanford_weights"

Load the weight of the Dense layer using weights from a stanford-nlp checkpoint.

**Parameters**

- **model_name_or_path** (*str | os.PathLike*)
- **cache_folder** (*str | os.PathLike | None*) – defaults to `None`
- **revision** (*str | None*) – defaults to `None`
- **local_files_only** (*bool | None*) – defaults to `None`
- **token** (*str | bool | None*) – defaults to `None`
- **use_auth_token** (*str | bool | None*) – defaults to `None`

???- note "get_buffer"

Return the buffer given by ``target`` if it exists, otherwise throw an error.
Expand Down Expand Up @@ -244,7 +257,7 @@ Performs linear projection on the token embeddings to a lower dimension.

Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. * **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.

**Parameters**

Expand Down
27 changes: 26 additions & 1 deletion docs/models/models.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
# Available models

Here is a list of the pre-trained ColBERT models available in PyLate along with their results on BEIR:
!!! tip
Following an update, all the models trained using the stanford-nlp ColBERT library or RAGatouille should be compatible with PyLate natively.
You can simply load the model in PyLate:

```python
from pylate import models

model = models.ColBERT(
model_name_or_path="colbert-ir/colbertv2.0",
)
```

Note that some models can use non-default parameters (e.g, prefixes) and so you might have to specify them, e.g for the [jinaai/jina-colbert-v2](https://huggingface.co/jinaai/jina-colbert-v2) model, you need to specify the prefixes and allow the query tokens to attend to expansion tokens:
```python
model = models.ColBERT(
model_name_or_path="jinaai/jina-colbert-v2",
query_prefix="[QueryMarker]",
document_prefix="[DocumentMarker]",
attend_to_expansion_tokens=True,
trust_remote_code=True,
)
```


Here is a list of some of the pre-trained ColBERT models available in PyLate along with their results on BEIR:

=== "Table"

| Model | BEIR AVG | NFCorpus | SciFact | SCIDOCS | FiQA2018 | TRECCOVID | HotpotQA | Touche2020 | ArguAna | ClimateFEVER | FEVER | QuoraRetrieval | NQ | DBPedia |
|---------------------------------------|----------|----------|---------|---------|----------|-----------|----------|------------|---------|--------------|-------|----------------|------|---------|
| [lightonai/colbertv2.0](https://huggingface.co/lightonai/colbertv2.0) | 50.02 | 33.8 | 69.3 | 15.4 | 35.6 | 73.3 | 66.7 | 26.3 | 46.3 | 17.6 | 78.5 | 85.2 | 56.2 | 44.6 |
| [answerdotai/answerai-colbert-small-v1](https://huggingface.co/answerdotai/answerai-colbert-small-v1) | 53.79 | 37.3 | 74.77 | 18.42 | 41.15 | 84.59 | 76.11 | 25.69 | 50.09 | 33.07 | 90.96 | 87.72 | 59.1 | 45.58 |
| [jinaai/jina-colbert-v2](https://huggingface.co/jinaai/jina-colbert-v2) | 53.1 | 34.6 | 67.8 | 18.6 | 40.8 | 83.4 | 76.6 | 27.4 | 36.6 | 23.9 | 80.05 | 88.7 | 64.0 | 47.1 |


???+ note
Expand Down
2 changes: 1 addition & 1 deletion pylate/__version__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
VERSION = (1, 1, 1)
VERSION = (1, 1, 2)

__version__ = ".".join(map(str, VERSION))
15 changes: 11 additions & 4 deletions pylate/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def __init__(
truncate_dim: int | None = None,
embedding_size: int | None = None,
bias: bool = False,
query_prefix: str | None = "[Q] ",
document_prefix: str | None = "[D] ",
query_prefix: str | None = None,
document_prefix: str | None = None,
add_special_tokens: bool = True,
truncation: bool = True,
query_length: int | None = None,
Expand Down Expand Up @@ -262,8 +262,10 @@ def __init__(
)
)
# Setting the prefixes from stanford-nlp models
self.query_prefix = "[unused0]"
self.document_prefix = "[unused1]"
if self.query_prefix is None:
self.query_prefix = "[unused0]"
if self.document_prefix is None:
self.document_prefix = "[unused1]"
logger.warning("Loaded the ColBERT model from Stanford NLP.")
else:
# Add a linear projection layer to the model in order to project the embeddings to the desired size
Expand Down Expand Up @@ -308,6 +310,11 @@ def __init__(
self.to(device)
self.is_hpu_graph_enabled = False

if self.query_prefix is None:
self.query_prefix = "[Q] "
if self.document_prefix is None:
self.document_prefix = "[D] "

# Try adding the prefixes to the tokenizer. We call resize_token_embeddings twice to ensure the tokens are added only if resize_token_embeddings works. There should be a better way to do this.
try:
self._first_module().auto_model.resize_token_embeddings(len(self.tokenizer))
Expand Down
Loading