Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(wren-ai-service): allow regenerate sql using retrieved tables #1324

Merged
merged 1 commit into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 16 additions & 83 deletions wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple

import orjson
import pandas as pd
Expand Down Expand Up @@ -183,8 +183,11 @@ def on_change_sql_generation_reasoning():
]


def on_click_regenerate_sql(changed_sql_generation_reasoning: str):
def on_click_regenerate_sql(
retrieved_tables: list[str], changed_sql_generation_reasoning: str
):
ask_feedback(
retrieved_tables,
changed_sql_generation_reasoning,
st.session_state["asks_results"]["response"][0]["sql"],
)
Expand Down Expand Up @@ -226,7 +229,11 @@ def show_asks_results():
st.markdown(f"{st.session_state['query']}")

st.markdown("### Retrieved Tables")
st.markdown(st.session_state["retrieved_tables"])
retrieved_tables = st.text_input(
"Enter the retrieved tables separated by commas, ex: table1, table2, table3",
st.session_state["retrieved_tables"],
key="retrieved_tables_input",
)

st.markdown("### SQL Generation Reasoning")
changed_sql_generation_reasoning = st.text_area(
Expand All @@ -240,7 +247,10 @@ def show_asks_results():
st.button(
"Regenerate SQL",
on_click=on_click_regenerate_sql,
args=(changed_sql_generation_reasoning,),
args=(
retrieved_tables.split(", "),
changed_sql_generation_reasoning,
),
)

st.markdown("### SQL Query Result")
Expand Down Expand Up @@ -306,84 +316,6 @@ def show_asks_details_results():
sqls_with_cte.append(f"{step['cte_name']} AS ( {step['sql']} )")


def on_click_preview_data_button(index: int, full_sqls: List[str]):
    """Callback for a "preview data" button.

    Records which preview button was pressed and stashes the full SQL
    statement that should be previewed, so the UI can render it on rerun.
    """
    state = st.session_state
    state["preview_data_button_index"] = index
    state["preview_sql"] = full_sqls[index]


def on_change_user_correction(
    step_idx: int, explanation_index: int, explanation_result: dict
):
    """Record, update, or remove a user's natural-language correction for one
    SQL-explanation item of a given step.

    The correction text is read from
    ``st.session_state[f"user_correction_{step_idx}_{explanation_index}"]``
    and the list of corrections for the step lives in
    ``st.session_state["sql_user_corrections_by_step"][step_idx]``.
    """

    def _get_decision_point(explanation_result: dict):
        # Map one explanation item to the "decision point" it corrects:
        # a dict of {"type": ..., "value": ...}. Falls through (returns
        # None implicitly) for unrecognized types.
        if explanation_result["type"] == "relation":
            if explanation_result["payload"]["type"] == "TABLE":
                return {
                    "type": explanation_result["type"],
                    "value": explanation_result["payload"]["tableName"],
                }
            elif explanation_result["payload"]["type"].endswith("_JOIN"):
                # e.g. INNER_JOIN / LEFT_JOIN — the join criteria is the value.
                return {
                    "type": explanation_result["type"],
                    "value": explanation_result["payload"]["criteria"],
                }
        elif explanation_result["type"] == "filter":
            return {
                "type": explanation_result["type"],
                "value": explanation_result["payload"]["expression"],
            }
        elif explanation_result["type"] == "groupByKeys":
            return {
                "type": explanation_result["type"],
                "value": explanation_result["payload"]["keys"],
            }
        elif explanation_result["type"] == "sortings":
            return {
                "type": explanation_result["type"],
                "value": explanation_result["payload"]["expression"],
            }
        elif explanation_result["type"] == "selectItems":
            return {
                "type": explanation_result["type"],
                "value": explanation_result["payload"]["expression"],
            }

    decision_point = _get_decision_point(explanation_result)

    # Look for an existing correction targeting the same decision point:
    # non-empty input updates it in place, empty input deletes it. Either
    # way we break immediately, so the pop-during-enumerate is safe here.
    should_add_new_correction = True
    for i, sql_user_correction in enumerate(
        st.session_state["sql_user_corrections_by_step"][step_idx]
    ):
        if sql_user_correction["before"] == decision_point:
            if st.session_state[f"user_correction_{step_idx}_{explanation_index}"]:
                st.session_state["sql_user_corrections_by_step"][step_idx][i][
                    "after"
                ] = {
                    "type": "nl_expression",
                    "value": st.session_state[
                        f"user_correction_{step_idx}_{explanation_index}"
                    ],
                }
                should_add_new_correction = False
                break
            else:
                # Empty text means the user cleared the correction — drop it.
                st.session_state["sql_user_corrections_by_step"][step_idx].pop(i)
                should_add_new_correction = False
                break

    if should_add_new_correction:
        # No prior correction for this decision point — append a new entry.
        # NOTE(review): this appends even when the input text is empty —
        # confirm that callers only fire this on a non-empty change.
        st.session_state["sql_user_corrections_by_step"][step_idx].append(
            {
                "before": decision_point,
                "after": {
                    "type": "nl_expression",
                    "value": st.session_state[
                        f"user_correction_{step_idx}_{explanation_index}"
                    ],
                },
            }
        )


def on_click_adjust_chart(
query: str,
sql: str,
Expand Down Expand Up @@ -598,10 +530,11 @@ def ask(query: str, timezone: str, query_history: Optional[dict] = None):
)


def ask_feedback(sql_generation_reasoning: str, sql: str):
def ask_feedback(tables: list[str], sql_generation_reasoning: str, sql: str):
ask_feedback_response = requests.post(
f"{WREN_AI_SERVICE_BASE_URL}/v1/ask-feedbacks",
json={
"tables": tables,
"sql_generation_reasoning": sql_generation_reasoning,
"sql": sql,
"configurations": {
Expand Down
12 changes: 11 additions & 1 deletion wren-ai-service/src/pipelines/generation/sql_regeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

sql_regeneration_system_prompt = f"""
### TASK ###
You are a great ANSI SQL expert. Now you are given a SQL generation reasoning and an original SQL query,
You are a great ANSI SQL expert. Now you are given database schema, SQL generation reasoning and an original SQL query,
please carefully review the reasoning, and then generate a new SQL query that matches the reasoning.
While generating the new SQL query, you should use the original SQL query as a reference.
While generating the new SQL query, make sure to use the database schema to generate the SQL query.

{TEXT_TO_SQL_RULES}

Expand All @@ -38,6 +39,11 @@
"""

sql_regeneration_user_prompt_template = """
### DATABASE SCHEMA ###
{% for document in documents %}
{{ document }}
{% endfor %}

{% if instructions %}
### INSTRUCTIONS ###
{{ instructions }}
Expand All @@ -54,6 +60,7 @@
## Start of Pipeline
@observe(capture_input=False)
def prompt(
documents: list[str],
sql_generation_reasoning: str,
sql: str,
prompt_builder: PromptBuilder,
Expand All @@ -63,6 +70,7 @@ def prompt(
) -> dict:
return prompt_builder.run(
sql=sql,
documents=documents,
sql_generation_reasoning=sql_generation_reasoning,
instructions=construct_instructions(
configuration,
Expand Down Expand Up @@ -129,6 +137,7 @@ def __init__(
@observe(name="SQL Regeneration")
async def run(
self,
contexts: list[str],
sql_generation_reasoning: str,
sql: str,
configuration: Configuration = Configuration(),
Expand All @@ -140,6 +149,7 @@ async def run(
return await self._pipe.execute(
["post_process"],
inputs={
"documents": contexts,
"sql_generation_reasoning": sql_generation_reasoning,
"sql": sql,
"project_id": project_id,
Expand Down
51 changes: 33 additions & 18 deletions wren-ai-service/src/pipelines/retrieval/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,25 @@ def _build_view_ddl(content: dict) -> str:
async def embedding(
query: str, embedder: Any, history: Optional[AskHistory] = None
) -> dict:
if history:
previous_query_summaries = [
step.summary for step in history.steps if step.summary
]
else:
previous_query_summaries = []
if query:
if history:
previous_query_summaries = [
step.summary for step in history.steps if step.summary
]
else:
previous_query_summaries = []

query = "\n".join(previous_query_summaries) + "\n" + query
query = "\n".join(previous_query_summaries) + "\n" + query

return await embedder.run(query)
return await embedder.run(query)
else:
return {}


@observe(capture_input=False)
async def table_retrieval(embedding: dict, id: str, table_retriever: Any) -> dict:
async def table_retrieval(
embedding: dict, id: str, tables: list[str], table_retriever: Any
) -> dict:
filters = {
"operator": "AND",
"conditions": [
Expand All @@ -144,15 +149,25 @@ async def table_retrieval(embedding: dict, id: str, table_retriever: Any) -> dic
{"field": "project_id", "operator": "==", "value": id}
)

return await table_retriever.run(
query_embedding=embedding.get("embedding"),
filters=filters,
)
if embedding:
return await table_retriever.run(
query_embedding=embedding.get("embedding"),
filters=filters,
)
else:
filters["conditions"].append(
{"field": "name", "operator": "in", "value": tables}
)

return await table_retriever.run(
query_embedding=[],
filters=filters,
)


@observe(capture_input=False)
async def dbschema_retrieval(
table_retrieval: dict, embedding: dict, id: str, dbschema_retriever: Any
table_retrieval: dict, id: str, dbschema_retriever: Any
) -> list[Document]:
tables = table_retrieval.get("documents", [])
table_names = []
Expand All @@ -178,9 +193,7 @@ async def dbschema_retrieval(
{"field": "project_id", "operator": "==", "value": id}
)

results = await dbschema_retriever.run(
query_embedding=embedding.get("embedding"), filters=filters
)
results = await dbschema_retriever.run(query_embedding=[], filters=filters)
return results["documents"]


Expand Down Expand Up @@ -466,7 +479,8 @@ def __init__(
@observe(name="Ask Retrieval")
async def run(
self,
query: str,
query: str = "",
tables: Optional[list[str]] = None,
id: Optional[str] = None,
history: Optional[AskHistory] = None,
):
Expand All @@ -475,6 +489,7 @@ async def run(
["construct_retrieval_results"],
inputs={
"query": query,
"tables": tables,
"id": id or "",
"history": history,
**self._components,
Expand Down
51 changes: 44 additions & 7 deletions wren-ai-service/src/providers/document_store/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,36 @@ async def _query_by_embedding(
document.score = score
return results

async def _query_by_filters(
self,
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
) -> List[Document]:
qdrant_filters = convert_filters_to_qdrant(filters)
points_list = []
offset = None
while True:
points = await self.async_client.scroll(
collection_name=self.index,
offset=offset,
scroll_filter=qdrant_filters,
limit=top_k,
)
points_list.extend(points[0])
if points[1] is None:
break
offset = points[1]
Comment on lines +220 to +230
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I executed the code and found that `None` appears at index 1 of the returned tuple, which satisfies the condition and exits the while loop. However, I suggest documenting this in the docstring for future reference.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, the `while` loop with an always-true condition poses a potential risk of an infinite loop; in my opinion, we should change the loop condition.


if points_list:
return [
convert_qdrant_point_to_haystack_document(
point, use_sparse_embeddings=self.use_sparse_embeddings
)
for point in points_list
]
else:
return []

async def delete_documents(self, filters: Optional[Dict[str, Any]] = None):
if not filters:
qdrant_filters = rest.Filter()
Expand Down Expand Up @@ -306,6 +336,7 @@ def __init__(
scale_score=scale_score,
return_embedding=return_embedding,
)
self._document_store = document_store

@component.output_types(documents=List[Document])
async def run(
Expand All @@ -316,13 +347,19 @@ async def run(
scale_score: Optional[bool] = None,
return_embedding: Optional[bool] = None,
):
docs = await self._document_store._query_by_embedding(
query_embedding=query_embedding,
filters=filters or self._filters,
top_k=top_k or self._top_k,
scale_score=scale_score or self._scale_score,
return_embedding=return_embedding or self._return_embedding,
)
if query_embedding:
docs = await self._document_store._query_by_embedding(
query_embedding=query_embedding,
filters=filters or self._filters,
top_k=top_k or self._top_k,
scale_score=scale_score or self._scale_score,
return_embedding=return_embedding or self._return_embedding,
)
else:
docs = await self._document_store._query_by_filters(
filters=filters,
top_k=top_k,
)

return {"documents": docs}

Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/web/v1/routers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def ask_feedback(
service_container.ask_service._ask_feedback_results[
query_id
] = AskFeedbackResultResponse(
status="understanding",
status="searching",
)

background_tasks.add_task(
Expand Down
Loading
Loading