diff --git a/wren-ai-service/demo/utils.py b/wren-ai-service/demo/utils.py
index 194064283..a72ddfdef 100644
--- a/wren-ai-service/demo/utils.py
+++ b/wren-ai-service/demo/utils.py
@@ -5,7 +5,7 @@
 import time
 import uuid
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import orjson
 import pandas as pd
@@ -183,8 +183,11 @@ def on_change_sql_generation_reasoning():
     ]
 
 
-def on_click_regenerate_sql(changed_sql_generation_reasoning: str):
+def on_click_regenerate_sql(
+    retrieved_tables: list[str], changed_sql_generation_reasoning: str
+):
     ask_feedback(
+        retrieved_tables,
         changed_sql_generation_reasoning,
         st.session_state["asks_results"]["response"][0]["sql"],
     )
@@ -226,7 +229,11 @@ def show_asks_results():
     st.markdown(f"{st.session_state['query']}")
 
     st.markdown("### Retrieved Tables")
-    st.markdown(st.session_state["retrieved_tables"])
+    retrieved_tables = st.text_input(
+        "Enter the retrieved tables separated by commas, ex: table1, table2, table3",
+        st.session_state["retrieved_tables"],
+        key="retrieved_tables_input",
+    )
 
     st.markdown("### SQL Generation Reasoning")
     changed_sql_generation_reasoning = st.text_area(
@@ -240,7 +247,10 @@
     st.button(
         "Regenerate SQL",
         on_click=on_click_regenerate_sql,
-        args=(changed_sql_generation_reasoning,),
+        args=(
+            retrieved_tables.split(", "),
+            changed_sql_generation_reasoning,
+        ),
     )
 
     st.markdown("### SQL Query Result")
@@ -306,84 +316,6 @@ def show_asks_details_results():
             sqls_with_cte.append(f"{step['cte_name']} AS ( {step['sql']} )")
 
 
-def on_click_preview_data_button(index: int, full_sqls: List[str]):
-    st.session_state["preview_data_button_index"] = index
-    st.session_state["preview_sql"] = full_sqls[index]
-
-
-def on_change_user_correction(
-    step_idx: int, explanation_index: int, explanation_result: dict
-):
-    def _get_decision_point(explanation_result: dict):
-        if explanation_result["type"] == "relation":
-            if explanation_result["payload"]["type"] == "TABLE":
-                return {
-                    "type": explanation_result["type"],
-                    "value": explanation_result["payload"]["tableName"],
-                }
-            elif explanation_result["payload"]["type"].endswith("_JOIN"):
-                return {
-                    "type": explanation_result["type"],
-                    "value": explanation_result["payload"]["criteria"],
-                }
-        elif explanation_result["type"] == "filter":
-            return {
-                "type": explanation_result["type"],
-                "value": explanation_result["payload"]["expression"],
-            }
-        elif explanation_result["type"] == "groupByKeys":
-            return {
-                "type": explanation_result["type"],
-                "value": explanation_result["payload"]["keys"],
-            }
-        elif explanation_result["type"] == "sortings":
-            return {
-                "type": explanation_result["type"],
-                "value": explanation_result["payload"]["expression"],
-            }
-        elif explanation_result["type"] == "selectItems":
-            return {
-                "type": explanation_result["type"],
-                "value": explanation_result["payload"]["expression"],
-            }
-
-    decision_point = _get_decision_point(explanation_result)
-
-    should_add_new_correction = True
-    for i, sql_user_correction in enumerate(
-        st.session_state["sql_user_corrections_by_step"][step_idx]
-    ):
-        if sql_user_correction["before"] == decision_point:
-            if st.session_state[f"user_correction_{step_idx}_{explanation_index}"]:
-                st.session_state["sql_user_corrections_by_step"][step_idx][i][
-                    "after"
-                ] = {
-                    "type": "nl_expression",
-                    "value": st.session_state[
-                        f"user_correction_{step_idx}_{explanation_index}"
-                    ],
-                }
-                should_add_new_correction = False
-                break
-            else:
-                st.session_state["sql_user_corrections_by_step"][step_idx].pop(i)
-                should_add_new_correction = False
-                break
-
-    if should_add_new_correction:
-        st.session_state["sql_user_corrections_by_step"][step_idx].append(
-            {
-                "before": decision_point,
-                "after": {
-                    "type": "nl_expression",
-                    "value": st.session_state[
-                        f"user_correction_{step_idx}_{explanation_index}"
-                    ],
-                },
-            }
-        )
-
-
 def on_click_adjust_chart(
     query: str,
     sql: str,
@@ -598,10 +530,11 @@ def ask(query: str, timezone: str, query_history: Optional[dict] = None):
     )
 
 
-def ask_feedback(sql_generation_reasoning: str, sql: str):
+def ask_feedback(tables: list[str], sql_generation_reasoning: str, sql: str):
     ask_feedback_response = requests.post(
         f"{WREN_AI_SERVICE_BASE_URL}/v1/ask-feedbacks",
         json={
+            "tables": tables,
             "sql_generation_reasoning": sql_generation_reasoning,
             "sql": sql,
             "configurations": {
diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py
index cc498a88f..4160d2df7 100644
--- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py
+++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py
@@ -23,9 +23,10 @@
 
 sql_regeneration_system_prompt = f"""
 ### TASK ###
-You are a great ANSI SQL expert. Now you are given a SQL generation reasoning and an original SQL query,
+You are a great ANSI SQL expert. Now you are given the database schema, the SQL generation reasoning and an original SQL query,
 please carefully review the reasoning, and then generate a new SQL query that matches the reasoning.
 While generating the new SQL query, you should use the original SQL query as a reference.
+While generating the new SQL query, make sure to ground it in the given database schema.
 
 {TEXT_TO_SQL_RULES}
 
@@ -38,6 +39,11 @@
 """
 
 sql_regeneration_user_prompt_template = """
+### DATABASE SCHEMA ###
+{% for document in documents %}
+    {{ document }}
+{% endfor %}
+
 {% if instructions %}
 ### INSTRUCTIONS ###
 {{ instructions }}
@@ -54,6 +60,7 @@
 ## Start of Pipeline
 @observe(capture_input=False)
 def prompt(
+    documents: list[str],
     sql_generation_reasoning: str,
     sql: str,
     prompt_builder: PromptBuilder,
@@ -63,6 +70,7 @@ def prompt(
 ) -> dict:
     return prompt_builder.run(
         sql=sql,
+        documents=documents,
         sql_generation_reasoning=sql_generation_reasoning,
         instructions=construct_instructions(
             configuration,
@@ -129,6 +137,7 @@ def __init__(
     @observe(name="SQL Regeneration")
     async def run(
         self,
+        contexts: list[str],
         sql_generation_reasoning: str,
         sql: str,
         configuration: Configuration = Configuration(),
@@ -140,6 +149,7 @@ async def run(
         return await self._pipe.execute(
             ["post_process"],
             inputs={
+                "documents": contexts,
                 "sql_generation_reasoning": sql_generation_reasoning,
                 "sql": sql,
                 "project_id": project_id,
diff --git a/wren-ai-service/src/pipelines/retrieval/retrieval.py b/wren-ai-service/src/pipelines/retrieval/retrieval.py
index fbfda607b..e5ad0e1a3 100644
--- a/wren-ai-service/src/pipelines/retrieval/retrieval.py
+++ b/wren-ai-service/src/pipelines/retrieval/retrieval.py
@@ -118,20 +118,25 @@ def _build_view_ddl(content: dict) -> str:
 async def embedding(
     query: str, embedder: Any, history: Optional[AskHistory] = None
 ) -> dict:
-    if history:
-        previous_query_summaries = [
-            step.summary for step in history.steps if step.summary
-        ]
-    else:
-        previous_query_summaries = []
+    if query:
+        if history:
+            previous_query_summaries = [
+                step.summary for step in history.steps if step.summary
+            ]
+        else:
+            previous_query_summaries = []
 
-    query = "\n".join(previous_query_summaries) + "\n" + query
+        query = "\n".join(previous_query_summaries) + "\n" + query
 
-    return await embedder.run(query)
+        return await embedder.run(query)
+    else:
+        return {}
 
 
 @observe(capture_input=False)
-async def table_retrieval(embedding: dict, id: str, table_retriever: Any) -> dict:
+async def table_retrieval(
+    embedding: dict, id: str, tables: list[str], table_retriever: Any
+) -> dict:
     filters = {
         "operator": "AND",
         "conditions": [
@@ -144,15 +149,25 @@ async def table_retrieval(embedding: dict, id: str, table_retriever: Any) -> dic
             {"field": "project_id", "operator": "==", "value": id}
         )
 
-    return await table_retriever.run(
-        query_embedding=embedding.get("embedding"),
-        filters=filters,
-    )
+    if embedding:
+        return await table_retriever.run(
+            query_embedding=embedding.get("embedding"),
+            filters=filters,
+        )
+    else:
+        filters["conditions"].append(
+            {"field": "name", "operator": "in", "value": tables}
+        )
+
+        return await table_retriever.run(
+            query_embedding=[],
+            filters=filters,
+        )
 
 
 @observe(capture_input=False)
 async def dbschema_retrieval(
-    table_retrieval: dict, embedding: dict, id: str, dbschema_retriever: Any
+    table_retrieval: dict, id: str, dbschema_retriever: Any
 ) -> list[Document]:
     tables = table_retrieval.get("documents", [])
     table_names = []
@@ -178,9 +193,7 @@
         {"field": "project_id", "operator": "==", "value": id}
     )
 
-    results = await dbschema_retriever.run(
-        query_embedding=embedding.get("embedding"), filters=filters
-    )
+    results = await dbschema_retriever.run(query_embedding=[], filters=filters)
 
     return results["documents"]
 
@@ -466,7 +479,8 @@ def __init__(
    @observe(name="Ask Retrieval")
    async def run(
        self,
-        query: str,
+        query: str = "",
+        tables: Optional[list[str]] = None,
        id: Optional[str] = None,
        history: Optional[AskHistory] = None,
    ):
@@ -475,6 +489,7 @@
            ["construct_retrieval_results"],
            inputs={
                "query": query,
+                "tables": tables,
                "id": id or "",
                "history": history,
                **self._components,
diff --git a/wren-ai-service/src/providers/document_store/qdrant.py b/wren-ai-service/src/providers/document_store/qdrant.py
index 13a740470..d471e788d 100644
--- a/wren-ai-service/src/providers/document_store/qdrant.py
+++ b/wren-ai-service/src/providers/document_store/qdrant.py
@@ -209,6 +209,36 @@ async def _query_by_embedding(
            document.score = score
        return results
 
+    async def _query_by_filters(
+        self,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: Optional[int] = None,
+    ) -> List[Document]:
+        qdrant_filters = convert_filters_to_qdrant(filters)
+        points_list = []
+        offset = None
+        while True:
+            points = await self.async_client.scroll(
+                collection_name=self.index,
+                offset=offset,
+                scroll_filter=qdrant_filters,
+                limit=top_k,
+            )
+            points_list.extend(points[0])
+            if points[1] is None:
+                break
+            offset = points[1]
+
+        if points_list:
+            return [
+                convert_qdrant_point_to_haystack_document(
+                    point, use_sparse_embeddings=self.use_sparse_embeddings
+                )
+                for point in points_list
+            ]
+        else:
+            return []
+
    async def delete_documents(self, filters: Optional[Dict[str, Any]] = None):
        if not filters:
            qdrant_filters = rest.Filter()
@@ -306,6 +336,7 @@ def __init__(
            scale_score=scale_score,
            return_embedding=return_embedding,
        )
+        self._document_store = document_store
 
    @component.output_types(documents=List[Document])
    async def run(
@@ -316,13 +347,19 @@ async def run(
        scale_score: Optional[bool] = None,
        return_embedding: Optional[bool] = None,
    ):
-        docs = await self._document_store._query_by_embedding(
-            query_embedding=query_embedding,
-            filters=filters or self._filters,
-            top_k=top_k or self._top_k,
-            scale_score=scale_score or self._scale_score,
-            return_embedding=return_embedding or self._return_embedding,
-        )
+        if query_embedding:
+            docs = await self._document_store._query_by_embedding(
+                query_embedding=query_embedding,
+                filters=filters or self._filters,
+                top_k=top_k or self._top_k,
+                scale_score=scale_score or self._scale_score,
+                return_embedding=return_embedding or self._return_embedding,
+            )
+        else:
+            docs = await self._document_store._query_by_filters(
+                filters=filters,
+                top_k=top_k,
+            )
 
        return {"documents": docs}
 
diff --git a/wren-ai-service/src/web/v1/routers/ask.py b/wren-ai-service/src/web/v1/routers/ask.py
index 7998787e6..7eb2f2cdd 100644
--- a/wren-ai-service/src/web/v1/routers/ask.py
+++ b/wren-ai-service/src/web/v1/routers/ask.py
@@ -155,7 +155,7 @@ async def ask_feedback(
    service_container.ask_service._ask_feedback_results[
        query_id
    ] = AskFeedbackResultResponse(
-        status="understanding",
+        status="searching",
    )
 
    background_tasks.add_task(
diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py
index 30f5df14f..088b3538b 100644
--- a/wren-ai-service/src/web/v1/services/ask.py
+++ b/wren-ai-service/src/web/v1/services/ask.py
@@ -102,6 +102,7 @@ class AskResultResponse(BaseModel):
 # POST /v1/ask-feedbacks
 class AskFeedbackRequest(BaseModel):
     _query_id: str | None = None
+    tables: List[str]
     sql_generation_reasoning: str
     sql: str
     project_id: Optional[str] = None
@@ -145,7 +146,7 @@ class AskFeedbackResultRequest(BaseModel):
 
 class AskFeedbackResultResponse(BaseModel):
     status: Literal[
-        "understanding",
+        "searching",
         "generating",
         "correcting",
         "finished",
@@ -586,8 +587,18 @@ async def ask_feedback(
        try:
            if not self._is_stopped(query_id, self._ask_feedback_results):
                self._ask_feedback_results[query_id] = AskFeedbackResultResponse(
-                    status="understanding",
+                    status="searching",
+                )
+
+                retrieval_result = await self._pipelines["retrieval"].run(
+                    tables=ask_feedback_request.tables,
+                    id=ask_feedback_request.project_id,
+                )
+                _retrieval_result = retrieval_result.get(
+                    "construct_retrieval_results", {}
                )
+                documents = _retrieval_result.get("retrieval_results", [])
+                table_ddls = [document.get("table_ddl") for document in documents]
 
            if not self._is_stopped(query_id, self._ask_feedback_results):
                self._ask_feedback_results[query_id] = AskFeedbackResultResponse(
@@ -597,6 +608,7 @@
                text_to_sql_generation_results = await self._pipelines[
                    "sql_regeneration"
                ].run(
+                    contexts=table_ddls,
                    sql_generation_reasoning=ask_feedback_request.sql_generation_reasoning,
                    sql=ask_feedback_request.sql,
                    project_id=ask_feedback_request.project_id,
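
Note for reviewers: below is a minimal sketch of how a client exercises the updated feedback flow, assuming a locally running wren-ai-service. The endpoint path, JSON field names, and the new "searching" status come from this diff; the base URL/port, table names, SQL text, and the shape of the POST response are illustrative assumptions, and the "configurations" object that the demo also sends is elided here.

import requests

# Assumed local deployment; adjust host/port to your environment.
WREN_AI_SERVICE_BASE_URL = "http://localhost:5556"

# AskFeedbackRequest now requires `tables`, which lets the service re-fetch
# those tables' DDLs via a filter-only Qdrant scroll instead of embedding
# an empty query.
response = requests.post(
    f"{WREN_AI_SERVICE_BASE_URL}/v1/ask-feedbacks",
    json={
        "tables": ["orders", "customers"],  # hypothetical table names
        "sql_generation_reasoning": "Count orders per customer.",
        "sql": "SELECT customer_id, COUNT(*) FROM orders GROUP BY customer_id",
    },
)

# The router keys feedback results by query_id, so the response is assumed to
# echo it back for polling; while the task runs, AskFeedbackResultResponse.status
# now starts at "searching" (table retrieval) rather than "understanding", then
# proceeds through "generating" and "correcting" to "finished".
query_id = response.json().get("query_id")
print(query_id)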