diff --git a/private_gpt/components/reranker/reranker.py b/private_gpt/components/reranker/reranker.py index 96671a1904..aa4f124598 100644 --- a/private_gpt/components/reranker/reranker.py +++ b/private_gpt/components/reranker/reranker.py @@ -7,6 +7,7 @@ from llama_index.postprocessor.types import BaseNodePostprocessor from private_gpt.settings.settings import Settings + @singleton class RerankerComponent(BaseNodePostprocessor): """ @@ -17,6 +18,7 @@ class RerankerComponent(BaseNodePostprocessor): If the number of nodes with score > cut_off is <= top_n, then return top_n nodes. Otherwise, return all nodes with score > cut_off. """ + reranker: FlagReranker = Field(description="Reranker class.") top_n: int = Field(description="Top N nodes to return.") cut_off: float = Field(description="Cut off score for nodes.") @@ -66,6 +68,4 @@ def _postprocess_nodes( if len(res) > self.top_n: return res - return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[ - : self.top_n - ] + return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[: self.top_n] diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py index eb1c296b95..4a50f2d860 100644 --- a/private_gpt/settings/settings.py +++ b/private_gpt/settings/settings.py @@ -111,7 +111,7 @@ class RerankerSettings(BaseModel): ) hf_model_name: str = Field( "BAAI/bge-reranker-large", - description="Name of the HuggingFace model to use for reranking" + description="Name of the HuggingFace model to use for reranking", ) top_n: int = Field( 5,