Skip to content

Commit

Permalink
Chore/lower bs for long docs (#134)
Browse files Browse the repository at this point in the history
* chore: auto-adjust batch size for longer docs

* chore: move function to more appropriate file

* fix bsize calc
  • Loading branch information
bclavie authored Feb 13, 2024
1 parent 4fbc9ce commit b7ae28a
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "RAGatouille"
version = "0.0.7"
version = "0.0.7post2"
description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts."
authors = ["Benjamin Clavie <[email protected]>"]
license = "Apache-2.0"
Expand Down
4 changes: 2 additions & 2 deletions ragatouille/RAGPretrainedModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def rerank(
documents: list[str],
k: int = 10,
zero_index_ranks: bool = False,
bsize: int = 64,
bsize: Union[Literal["auto"], int] = "auto",
):
"""Encode documents and rerank them in-memory. Performance degrades rapidly with more documents.
Expand Down Expand Up @@ -337,7 +337,7 @@ def rerank(
def encode(
self,
documents: list[str],
bsize: int = 32,
bsize: Union[Literal["auto"], int] = "auto",
document_metadatas: Optional[list[dict]] = None,
verbose: bool = True,
max_document_length: Union[Literal["auto"], int] = "auto",
Expand Down
2 changes: 1 addition & 1 deletion ragatouille/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.7"
__version__ = "0.0.7post2"
from .RAGPretrainedModel import RAGPretrainedModel
from .RAGTrainer import RAGTrainer

Expand Down
37 changes: 34 additions & 3 deletions ragatouille/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from ragatouille.models.base import LateInteractionModel

# TODO: Move all bsize related calcs to `_set_bsize()`


class ColBERT(LateInteractionModel):
def __init__(
Expand Down Expand Up @@ -630,7 +632,7 @@ def _index_free_retrieve(
k: int,
max_tokens: Union[Literal["auto"], int] = "auto",
zero_index: bool = False,
bsize: int = 32,
bsize: Union[Literal["auto"], int] = "auto",
):
self._set_inference_max_tokens(documents=documents, max_tokens=max_tokens)

Expand Down Expand Up @@ -663,8 +665,12 @@ def _index_free_retrieve(
)

def _encode_index_free_queries(
self, queries: Union[str, list[str]], bsize: int = 32
self,
queries: Union[str, list[str]],
bsize: Union[Literal["auto"], int] = "auto",
):
if bsize == "auto":
bsize = 32
if isinstance(queries, str):
queries = [queries]
maxlen = max([int(len(x.split(" ")) * 1.35) for x in queries])
Expand All @@ -678,8 +684,31 @@ def _encode_index_free_queries(
return embedded_queries

def _encode_index_free_documents(
self, documents: list[str], bsize: int = 32, verbose: bool = True
self,
documents: list[str],
bsize: Union[Literal["auto"], int] = "auto",
verbose: bool = True,
):
if bsize == "auto":
bsize = 32
if self.inference_ckpt.doc_tokenizer.doc_maxlen > 512:
bsize = max(
1,
int(
32
/ (
2
** round(
math.log(
self.inference_ckpt.doc_tokenizer.doc_maxlen, 2
)
)
/ 512
)
),
)
print("BSIZE:")
print(bsize)
embedded_docs = self.inference_ckpt.docFromText(
documents, bsize=bsize, showprogress=verbose
)[0]
Expand All @@ -694,6 +723,8 @@ def rank(
zero_index_ranks: bool = False,
bsize: int = 32,
):
self._set_inference_max_tokens(documents=documents, max_tokens="auto")
self.inference_ckpt_len_set = False
return self._index_free_retrieve(
query, documents, k, zero_index=zero_index_ranks, bsize=bsize
)
Expand Down

0 comments on commit b7ae28a

Please sign in to comment.