-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0150905
commit fb5dee4
Showing
4 changed files
with
133 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
"""MultiVectorRetriever with `similarity_score_threshold` option.""" | ||
|
||
from collections.abc import Collection | ||
from typing import ClassVar | ||
|
||
from langchain.retrievers.multi_vector import MultiVectorRetriever | ||
from langchain_core.callbacks import ( | ||
AsyncCallbackManagerForRetrieverRun, | ||
CallbackManagerForRetrieverRun, | ||
) | ||
from langchain_core.documents import Document | ||
from pydantic import root_validator | ||
|
||
|
||
class ThresholdedMultiVectorRetriever(MultiVectorRetriever):
    """MultiVectorRetriever with a selectable vector-store search strategy.

    Sub-documents are looked up in ``self.vectorstore`` using the strategy
    named by ``search_type``; the parent ids found in their metadata (under
    ``self.id_key``) are then fetched from ``self.docstore``.

    When ``search_type`` is ``"similarity_score_threshold"``,
    ``search_kwargs`` must contain a numeric ``score_threshold``. It is
    forwarded unchanged to the vector store together with the other search
    kwargs (the filtering itself is presumably done by the vector store --
    confirm against the store implementation in use).
    """

    search_type: str = "similarity"
    """Type of search to perform. Defaults to "similarity"."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
    )

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    # skip_on_failure=True: if an individual field already failed validation
    # its key is absent from `values`; skipping avoids a raw KeyError here and
    # is also required by pydantic v2's `root_validator` compatibility shim.
    @root_validator(skip_on_failure=True)
    def validate_search_type(cls, values: dict) -> dict:
        """Validate ``search_type`` and its required search kwargs.

        Args:
            values: Field values of the model being validated.

        Returns:
            The unchanged ``values`` dict.

        Raises:
            ValueError: If ``search_type`` is not one of
                ``allowed_search_types``, or if it is
                ``"similarity_score_threshold"`` and ``search_kwargs`` has no
                numeric ``score_threshold``.
        """
        search_type = values["search_type"]
        if search_type not in cls.allowed_search_types:
            raise ValueError(
                f"search_type of {search_type} not allowed. Valid values are: "
                f"{cls.allowed_search_types}"
            )
        if search_type == "similarity_score_threshold":
            search_kwargs = values.get("search_kwargs") or {}
            score_threshold = search_kwargs.get("score_threshold")
            # Accept ints as well as floats (e.g. 0 or 1), but not bools,
            # which are technically ints and almost certainly a mistake.
            is_number = isinstance(score_threshold, (int, float)) and not isinstance(
                score_threshold, bool
            )
            if not is_number:
                raise ValueError(
                    "`score_threshold` is not specified with a float value(0~1) "
                    "in `search_kwargs`."
                )
        return values

    def _ordered_parent_ids(self, sub_docs: list[Document]) -> list:
        """Return the unique parent ids of ``sub_docs`` in first-seen order.

        Sub-documents whose metadata lacks ``self.id_key`` are skipped.
        """
        # List membership keeps the retrieval order and makes no assumption
        # about the ids being hashable.
        ids: list = []
        for sub_doc in sub_docs:
            if self.id_key not in sub_doc.metadata:
                continue
            parent_id = sub_doc.metadata[self.id_key]
            if parent_id not in ids:
                ids.append(parent_id)
        return ids

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,  # noqa: ARG002
    ) -> list[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use (currently unused)

        Returns:
            List of relevant documents, ordered by the first matching
            sub-document; ids missing from the docstore are dropped.

        Raises:
            ValueError: If ``search_type`` is not a supported value.
        """
        if self.search_type == "similarity":
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            sub_docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            # Only the documents are needed; the scores are discarded.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        elif self.search_type == "mmr":
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")

        docs = self.docstore.mget(self._ordered_parent_ids(sub_docs))
        return [d for d in docs if d is not None]

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,  # noqa: ARG002
    ) -> list[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use (currently unused)

        Returns:
            List of relevant documents, ordered by the first matching
            sub-document; ids missing from the docstore are dropped.

        Raises:
            ValueError: If ``search_type`` is not a supported value.
        """
        if self.search_type == "similarity":
            sub_docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )
        elif self.search_type == "similarity_score_threshold":
            sub_docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            # Only the documents are needed; the scores are discarded.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        elif self.search_type == "mmr":
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")

        docs = await self.docstore.amget(self._ordered_parent_ids(sub_docs))
        return [d for d in docs if d is not None]