Skip to content

Commit

Permalink
✨ add score threshold in retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-pasquier committed Mar 21, 2024
1 parent 0150905 commit fb5dee4
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 5 deletions.
3 changes: 2 additions & 1 deletion backend/rag_1/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ vectorstore:
retriever:
_target_: langchain_core.vectorstores.VectorStoreRetriever
vectorstore: ${..vectorstore}
search_type: "similarity"
search_type: "similarity_score_threshold"
search_kwargs:
k: 4
score_threshold: 0.5

ingest:
clear_database: True
Expand Down
5 changes: 3 additions & 2 deletions backend/rag_2/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ store:
root_path: "${..path.database}/multi_vector_retriever_metadata/"

retriever:
_target_: langchain.retrievers.multi_vector.MultiVectorRetriever
_target_: backend.utils.multi_vector.ThresholdedMultiVectorRetriever
vectorstore: ${..vectorstore}
byte_store: ${..store}
id_key: "doc_id"
search_type: "similarity"
search_type: "similarity_score_threshold"
search_kwargs:
k: 4
score_threshold: 0.5

ingest:
clear_database: True
Expand Down
5 changes: 3 additions & 2 deletions backend/rag_3/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ store:
root_path: "${..path.database}/multi_vector_retriever_metadata/"

retriever:
_target_: langchain.retrievers.multi_vector.MultiVectorRetriever
_target_: backend.utils.multi_vector.ThresholdedMultiVectorRetriever
vectorstore: ${..vectorstore}
byte_store: ${..store}
id_key: "doc_id"
search_type: "similarity"
search_type: "similarity_score_threshold"
search_kwargs:
k: 4
score_threshold: 0.5

ingest:
clear_database: True
Expand Down
125 changes: 125 additions & 0 deletions backend/utils/multi_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""MultiVectorRetriever with `similarity_score_threshold` option."""

from collections.abc import Collection
from typing import ClassVar

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from pydantic import root_validator


class ThresholdedMultiVectorRetriever(MultiVectorRetriever):
"""MultiVectorRetriever the ability to search by similarity with a threshold."""

search_type: str = "similarity"
"""Type of search to perform. Defaults to "similarity"."""
allowed_search_types: ClassVar[Collection[str]] = (
"similarity",
"similarity_score_threshold",
"mmr",
)

class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True

@root_validator()
def validate_search_type(cls, values: dict) -> dict:
"""Validate search type."""
search_type = values["search_type"]
if search_type not in cls.allowed_search_types:
raise ValueError(
f"search_type of {search_type} not allowed. Valid values are: "
f"{cls.allowed_search_types}"
)
if search_type == "similarity_score_threshold":
score_threshold = values["search_kwargs"].get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
raise ValueError(
"`score_threshold` is not specified with a float value(0~1) "
"in `search_kwargs`."
)
return values

def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun, # noqa: ARG002
) -> list[Document]:
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
if self.search_type == "similarity":
sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_score_threshold":
sub_docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores(
query, **self.search_kwargs
)
)
sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
elif self.search_type == "mmr":
sub_docs = self.vectorstore.max_marginal_relevance_search(
query, **self.search_kwargs
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")

# We do this to maintain the order of the ids that are returned
ids = []
for d in sub_docs:
if self.id_key in d.metadata and d.metadata[self.id_key] not in ids:
ids.append(d.metadata[self.id_key])
docs = self.docstore.mget(ids)
return [d for d in docs if d is not None]

async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun, # noqa: ARG002
) -> list[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
if self.search_type == "similarity":
sub_docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs
)
elif self.search_type == "similarity_score_threshold":
sub_docs_and_similarities = (
await self.vectorstore.asimilarity_search_with_relevance_scores(
query, **self.search_kwargs
)
)
sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
elif self.search_type == "mmr":
sub_docs = await self.vectorstore.amax_marginal_relevance_search(
query, **self.search_kwargs
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")

# We do this to maintain the order of the ids that are returned
ids = []
for d in sub_docs:
if self.id_key in d.metadata and d.metadata[self.id_key] not in ids:
ids.append(d.metadata[self.id_key])
docs = await self.docstore.amget(ids)
return [d for d in docs if d is not None]

0 comments on commit fb5dee4

Please sign in to comment.