From 841158e2a03bfc0ebe45a96959f35d81064105fe Mon Sep 17 00:00:00 2001 From: Anhui-tqhuang Date: Thu, 14 Mar 2024 20:16:58 +0800 Subject: [PATCH] fix: black --- .../reranker/flagembedding_reranker.py | 33 ++++++-------- private_gpt/components/reranker/reranker.py | 44 +++++++------------ private_gpt/server/chat/chat_service.py | 2 +- 3 files changed, 32 insertions(+), 47 deletions(-) diff --git a/private_gpt/components/reranker/flagembedding_reranker.py b/private_gpt/components/reranker/flagembedding_reranker.py index b0ec07b35..e56888f72 100644 --- a/private_gpt/components/reranker/flagembedding_reranker.py +++ b/private_gpt/components/reranker/flagembedding_reranker.py @@ -1,3 +1,4 @@ +import logging from typing import ( # noqa: UP035, we need to keep the consistence with llamaindex List, Tuple, @@ -8,8 +9,7 @@ from llama_index.core.indices.postprocessor import BaseNodePostprocessor from llama_index.core.schema import NodeWithScore, QueryBundle -from private_gpt.paths import models_path -from private_gpt.settings.settings import Settings +logger = logging.getLogger(__name__) class FlagEmbeddingRerankerComponent(BaseNodePostprocessor): @@ -22,23 +22,9 @@ class FlagEmbeddingRerankerComponent(BaseNodePostprocessor): Otherwise, return all nodes with score > cut_off. """ - reranker: FlagReranker = Field(description="Reranker class.") - top_n: int = Field(description="Top N nodes to return.") - cut_off: float = Field(description="Cut off score for nodes.") - - def __init__(self, settings: Settings) -> None: - path = models_path / "flagembedding_reranker" - top_n = settings.flagembedding_reranker.top_n - cut_off = settings.flagembedding_reranker.cut_off - reranker = FlagReranker( - model_name_or_path=path, - ) - - super().__init__( - top_n=top_n, - cut_off=cut_off, - reranker=reranker, - ) + top_n: int = Field(10, description="Top N nodes to return.") + cut_off: float = Field(0.0, description="Cut off score for nodes.") + reranker: FlagReranker = Field(..., description="Flag Reranker model.") @classmethod def class_name(cls) -> str: @@ -52,6 +38,9 @@ def _postprocess_nodes( if query_bundle is None: raise ValueError("Query bundle must be provided.") + logger.info("Postprocessing nodes with FlagEmbeddingReranker.") + logger.info(f"top_n: {self.top_n}, cut_off: {self.cut_off}") + query_str = query_bundle.query_str sentence_pairs: List[Tuple[str, str]] = [] # noqa: UP006 for node in nodes: @@ -65,6 +54,12 @@ def _postprocess_nodes( # cut off nodes with low scores res = [node for node in nodes if (node.score or 0.0) > self.cut_off] if len(res) > self.top_n: + logger.info( + "Number of nodes with score > cut_off is > top_n, returning all nodes with score > cut_off." + ) return res + logger.info( + "Number of nodes with score > cut_off is <= top_n, returning top_n nodes." + ) return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[: self.top_n] diff --git a/private_gpt/components/reranker/reranker.py b/private_gpt/components/reranker/reranker.py index cabf683cb..5468b9b8a 100644 --- a/private_gpt/components/reranker/reranker.py +++ b/private_gpt/components/reranker/reranker.py @@ -1,20 +1,15 @@ import logging -from typing import ( # noqa: UP035, we need to keep the consistence with llamaindex - List, -) from injector import inject, singleton -from llama_index.core.bridge.pydantic import Field -from llama_index.core.indices.postprocessor import BaseNodePostprocessor -from llama_index.core.schema import NodeWithScore, QueryBundle +from private_gpt.paths import models_path from private_gpt.settings.settings import Settings logger = logging.getLogger(__name__) @singleton -class RerankerComponent(BaseNodePostprocessor): +class RerankerComponent: """Reranker component. - mode: Reranker mode. @@ -22,10 +17,6 @@ class RerankerComponent(BaseNodePostprocessor): """ - nodePostPorcesser: BaseNodePostprocessor = Field( - description="BaseNodePostprocessor class." - ) - @inject def __init__(self, settings: Settings) -> None: if settings.reranker.enabled is False: @@ -38,6 +29,8 @@ def __init__(self, settings: Settings) -> None: ) try: + from FlagEmbedding import FlagReranker # type: ignore + from private_gpt.components.reranker.flagembedding_reranker import ( FlagEmbeddingRerankerComponent, ) @@ -46,24 +39,21 @@ def __init__(self, settings: Settings) -> None: "Local dependencies not found, install with `poetry install --extras reranker-flagembedding`" ) from e - nodePostPorcesser = FlagEmbeddingRerankerComponent(settings) + path = models_path / "flagembedding_reranker" + + if settings.flagembedding_reranker is None: + raise ValueError("FlagEmbeddingReranker settings is not provided.") + + top_n = settings.flagembedding_reranker.top_n + cut_off = settings.flagembedding_reranker.cut_off + flagReranker = FlagReranker( + model_name_or_path=path, + ) + self.nodePostPorcesser = FlagEmbeddingRerankerComponent( + top_n=top_n, cut_off=cut_off, reranker=flagReranker + ) case _: raise ValueError( "Reranker mode not supported, currently only support flagembedding." ) - - super().__init__( - nodePostPorcesser=nodePostPorcesser, - ) - - @classmethod - def class_name(cls) -> str: - return "Reranker" - - def _postprocess_nodes( - self, - nodes: List[NodeWithScore], # noqa: UP006 - query_bundle: QueryBundle | None = None, - ) -> List[NodeWithScore]: # noqa: UP006 - return self.nodePostPorcesser._postprocess_nodes(nodes, query_bundle) diff --git a/private_gpt/server/chat/chat_service.py b/private_gpt/server/chat/chat_service.py index 501c70fbb..da67dd335 100644 --- a/private_gpt/server/chat/chat_service.py +++ b/private_gpt/server/chat/chat_service.py @@ -129,7 +129,7 @@ def _chat_engine( ] if self.reranker_component: - node_postprocessors.append(self.reranker_component) + node_postprocessors.append(self.reranker_component.nodePostPorcesser) return ContextChatEngine.from_defaults( system_prompt=system_prompt,