Skip to content

Commit

Permalink
Allow to set the prefixes for stanford-nlp models (#55)
Browse files Browse the repository at this point in the history
* Allow to set the prefixes for stanford-nlp models

* Version bump

* Add documentation for the jina-colbert-v2 model
  • Loading branch information
NohTow authored Sep 13, 2024
1 parent 0c74287 commit d526d98
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 15 deletions.
4 changes: 2 additions & 2 deletions docs/api/losses/Contrastive.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Contrastive loss. Expects as input two texts and a label of either 0 or 1. If th

ColBERT model.

- **score_metric** – defaults to `<function colbert_scores at 0x14073dcf0>`
- **score_metric** – defaults to `<function colbert_scores at 0x7fc43e97f7e0>`

ColBERT scoring function. Defaults to colbert_scores.

Expand Down Expand Up @@ -228,7 +228,7 @@ Contrastive loss. Expects as input two texts and a label of either 0 or 1. If th

Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. * **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.

**Parameters**

Expand Down
4 changes: 2 additions & 2 deletions docs/api/losses/Distillation.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Distillation loss for ColBERT model. The loss is computed with respect to the fo

SentenceTransformer model.

- **score_metric** (*Callable*) – defaults to `<function colbert_kd_scores at 0x16ec65120>`
- **score_metric** (*Callable*) – defaults to `<function colbert_kd_scores at 0x7fc43ea4bc40>`

Function that returns a score between two sequences of embeddings.

Expand Down Expand Up @@ -232,7 +232,7 @@ Distillation loss for ColBERT model. The loss is computed with respect to the fo

Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. * **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.

**Parameters**

Expand Down
8 changes: 4 additions & 4 deletions docs/api/models/ColBERT.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ Loads or creates a ColBERT model that can be used to map sentences / text to mul

- **bias** (*bool*) – defaults to `False`

- **query_prefix** (*str | None*) – defaults to `[Q] `
- **query_prefix** (*str | None*) – defaults to `None`

Prefix to add to the queries.

- **document_prefix** (*str | None*) – defaults to `[D] `
- **document_prefix** (*str | None*) – defaults to `None`

Prefix to add to the documents.

Expand Down Expand Up @@ -494,7 +494,7 @@ Loads or creates a ColBERT model that can be used to map sentences / text to mul

Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. * **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.

**Parameters**

Expand Down Expand Up @@ -602,7 +602,7 @@ Loads or creates a ColBERT model that can be used to map sentences / text to mul

**Returns**

*list*: A list of pooled embeddings for each document.
*list[torch.Tensor]*: A list of pooled embeddings for each document.

???- note "pop"

Expand Down
15 changes: 14 additions & 1 deletion docs/api/models/Dense.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,19 @@ Performs linear projection on the token embeddings to a lower dimension.

- **dense** (*sentence_transformers.models.Dense.Dense*)

???- note "from_stanford_weights"

Load the weight of the Dense layer using weights from a stanford-nlp checkpoint.

**Parameters**

- **model_name_or_path** (*str | os.PathLike*)
- **cache_folder** (*str | os.PathLike | None*) – defaults to `None`
- **revision** (*str | None*) – defaults to `None`
- **local_files_only** (*bool | None*) – defaults to `None`
- **token** (*str | bool | None*) – defaults to `None`
- **use_auth_token** (*str | bool | None*) – defaults to `None`

???- note "get_buffer"

Return the buffer given by ``target`` if it exists, otherwise throw an error.
Expand Down Expand Up @@ -244,7 +257,7 @@ Performs linear projection on the token embeddings to a lower dimension.

Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.
If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. .. warning:: If :attr:`assign` is ``True`` the optimizer must be created after the call to :attr:`load_state_dict` unless :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. The only exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: * **missing_keys** is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. * **unexpected_keys** is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: If a parameter or buffer is registered as ``None`` and its corresponding key exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ``RuntimeError``.

**Parameters**

Expand Down
27 changes: 26 additions & 1 deletion docs/models/models.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
# Available models

Here is a list of the pre-trained ColBERT models available in PyLate along with their results on BEIR:
!!! tip
Following an update, all the models trained using the stanford-nlp ColBERT library or RAGatouille should be compatible with PyLate natively.
You can simply load the model in PyLate:

```python
from pylate import models

model = models.ColBERT(
model_name_or_path="colbert-ir/colbertv2.0",
)
```

Note that some models can use non-default parameters (e.g, prefixes) and so you might have to specify them, e.g for the [jinaai/jina-colbert-v2](https://huggingface.co/jinaai/jina-colbert-v2) model, you need to specify the prefixes and allow the query tokens to attend to expansion tokens:
```python
model = models.ColBERT(
model_name_or_path="jinaai/jina-colbert-v2",
query_prefix="[QueryMarker]",
document_prefix="[DocumentMarker]",
attend_to_expansion_tokens=True,
trust_remote_code=True,
)
```


Here is a list of some of the pre-trained ColBERT models available in PyLate along with their results on BEIR:

=== "Table"

| Model | BEIR AVG | NFCorpus | SciFact | SCIDOCS | FiQA2018 | TRECCOVID | HotpotQA | Touche2020 | ArguAna | ClimateFEVER | FEVER | QuoraRetrieval | NQ | DBPedia |
|---------------------------------------|----------|----------|---------|---------|----------|-----------|----------|------------|---------|--------------|-------|----------------|------|---------|
| [lightonai/colbertv2.0](https://huggingface.co/lightonai/colbertv2.0) | 50.02 | 33.8 | 69.3 | 15.4 | 35.6 | 73.3 | 66.7 | 26.3 | 46.3 | 17.6 | 78.5 | 85.2 | 56.2 | 44.6 |
| [answerdotai/answerai-colbert-small-v1](https://huggingface.co/answerdotai/answerai-colbert-small-v1) | 53.79 | 37.3 | 74.77 | 18.42 | 41.15 | 84.59 | 76.11 | 25.69 | 50.09 | 33.07 | 90.96 | 87.72 | 59.1 | 45.58 |
| [jinaai/jina-colbert-v2](https://huggingface.co/jinaai/jina-colbert-v2) | 53.1 | 34.6 | 67.8 | 18.6 | 40.8 | 83.4 | 76.6 | 27.4 | 36.6 | 23.9 | 80.05 | 88.7 | 64.0 | 47.1 |


???+ note
Expand Down
2 changes: 1 addition & 1 deletion pylate/__version__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
VERSION = (1, 1, 1)
VERSION = (1, 1, 2)

__version__ = ".".join(map(str, VERSION))
15 changes: 11 additions & 4 deletions pylate/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def __init__(
truncate_dim: int | None = None,
embedding_size: int | None = None,
bias: bool = False,
query_prefix: str | None = "[Q] ",
document_prefix: str | None = "[D] ",
query_prefix: str | None = None,
document_prefix: str | None = None,
add_special_tokens: bool = True,
truncation: bool = True,
query_length: int | None = None,
Expand Down Expand Up @@ -262,8 +262,10 @@ def __init__(
)
)
# Setting the prefixes from stanford-nlp models
self.query_prefix = "[unused0]"
self.document_prefix = "[unused1]"
if self.query_prefix is None:
self.query_prefix = "[unused0]"
if self.document_prefix is None:
self.document_prefix = "[unused1]"
logger.warning("Loaded the ColBERT model from Stanford NLP.")
else:
# Add a linear projection layer to the model in order to project the embeddings to the desired size
Expand Down Expand Up @@ -308,6 +310,11 @@ def __init__(
self.to(device)
self.is_hpu_graph_enabled = False

if self.query_prefix is None:
self.query_prefix = "[Q] "
if self.document_prefix is None:
self.document_prefix = "[D] "

# Try adding the prefixes to the tokenizer. We call resize_token_embeddings twice to ensure the tokens are added only if resize_token_embeddings works. There should be a better way to do this.
try:
self._first_module().auto_model.resize_token_embeddings(len(self.tokenizer))
Expand Down

0 comments on commit d526d98

Please sign in to comment.