From 9cef24b56953f179a1e1399b4dc994346597a808 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Wed, 5 Feb 2025 20:36:11 +0800 Subject: [PATCH 1/2] allow skip intent_classification --- wren-ai-service/src/config.py | 1 + wren-ai-service/src/globals.py | 1 + wren-ai-service/src/web/v1/services/ask.py | 4 +++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/wren-ai-service/src/config.py b/wren-ai-service/src/config.py index dc19b02c8..d1ead0d60 100644 --- a/wren-ai-service/src/config.py +++ b/wren-ai-service/src/config.py @@ -32,6 +32,7 @@ class Settings(BaseSettings): allow_using_db_schemas_without_pruning: bool = Field(default=False) # generation config + allow_intent_classification: bool = Field(default=True) allow_sql_generation_reasoning: bool = Field(default=True) # engine config diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 9ed3e31c3..123622d7e 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -124,6 +124,7 @@ def create_service_container( **pipe_components["sql_summary"], ), }, + allow_intent_classification=settings.allow_intent_classification, allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning, **query_cache, ), diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 155cb67cf..468dd2cce 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -102,6 +102,7 @@ class AskService: def __init__( self, pipelines: Dict[str, BasicPipeline], + allow_intent_classification: bool = True, allow_sql_generation_reasoning: bool = True, maxsize: int = 1_000_000, ttl: int = 120, @@ -111,6 +112,7 @@ def __init__( maxsize=maxsize, ttl=ttl ) self._allow_sql_generation_reasoning = allow_sql_generation_reasoning + self._allow_intent_classification = allow_intent_classification def _is_stopped(self, query_id: str): if ( @@ -173,7 +175,7 @@ async def ask( for result in historical_question_result ] sql_generation_reasoning = "" - else: + elif self._allow_intent_classification: intent_classification_result = ( await self._pipelines["intent_classification"].run( query=ask_request.query, From 3fd08c70d5ba4af2963b782f48069f7becf16067 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Wed, 5 Feb 2025 20:40:44 +0800 Subject: [PATCH 2/2] fix bug --- wren-ai-service/src/web/v1/services/ask.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 468dd2cce..26145986f 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -146,6 +146,8 @@ async def ask( error_message = "" try: + user_query = ask_request.query + # ask status can be understanding, searching, generating, finished, failed, stopped # we will need to handle business logic for each status if not self._is_stopped(query_id): @@ -154,7 +156,7 @@ async def ask( ) historical_question = await self._pipelines["historical_question"].run( - query=ask_request.query, + query=user_query, id=ask_request.project_id, ) @@ -178,7 +180,7 @@ async def ask( elif self._allow_intent_classification: intent_classification_result = ( await self._pipelines["intent_classification"].run( - query=ask_request.query, + query=user_query, history=ask_request.history, id=ask_request.project_id, configuration=ask_request.configurations, @@ -190,11 +192,8 @@ async def ask( ) intent_reasoning = intent_classification_result.get("reasoning") - user_query = ( - ask_request.query - if not rephrased_question - else rephrased_question - ) + if rephrased_question: + user_query = rephrased_question if intent == "MISLEADING_QUERY": self._ask_results[query_id] = AskResultResponse(