Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Feb 6, 2025
1 parent 327f6da commit 3166c53
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 4 deletions.
5 changes: 3 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
### TASK ###
You are an ANSI SQL expert with exceptional logical thinking skills and debugging skills.
Now you are given syntactically incorrect ANSI SQL query and related error message.
With given database schema, please generate the syntactically correct ANSI SQL query without changing original semantics.
Now you are given syntactically incorrect ANSI SQL query and related error message, please generate the syntactically correct ANSI SQL query without changing original semantics.
{TEXT_TO_SQL_RULES}
Expand All @@ -39,10 +38,12 @@
"""

sql_correction_user_prompt_template = """
{% if documents %}
### DATABASE SCHEMA ###
{% for document in documents %}
{{ document }}
{% endfor %}
{% endif %}
### QUESTION ###
SQL: {{ invalid_generation_result.sql }}
Expand Down
87 changes: 85 additions & 2 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,10 @@ async def ask(
if failed_dry_run_results[0]["type"] != "TIME_OUT":
self._ask_results[query_id] = AskResultResponse(
status="correcting",
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
sql_generation_reasoning=sql_generation_reasoning,
)
sql_correction_results = await self._pipelines[
"sql_correction"
Expand Down Expand Up @@ -560,8 +564,8 @@ async def ask_feedback(
}

query_id = ask_feedback_request.query_id
# api_results = []
# error_message = ""
api_results = []
error_message = ""

try:
if not self._is_stopped(query_id, self._ask_feedback_results):
Expand All @@ -574,6 +578,85 @@ async def ask_feedback(
status="generating",
)

text_to_sql_generation_results = await self._pipelines[
"sql_regeneration"
].run(
sql_generation_reasoning=ask_feedback_request.sql_generation_reasoning,
sql=ask_feedback_request.sql,
project_id=ask_feedback_request.project_id,
configuration=ask_feedback_request.configurations,
)

if sql_valid_results := text_to_sql_generation_results["post_process"][
"valid_generation_results"
]:
api_results = [
AskResult(
**{
"sql": result.get("sql"),
"type": "llm",
}
)
for result in sql_valid_results
][:1]
elif failed_dry_run_results := text_to_sql_generation_results[
"post_process"
]["invalid_generation_results"]:
if failed_dry_run_results[0]["type"] != "TIME_OUT":
self._ask_feedback_results[
query_id
] = AskFeedbackResultResponse(
status="correcting",
)
sql_correction_results = await self._pipelines[
"sql_correction"
].run(
contexts=[],
invalid_generation_results=failed_dry_run_results,
project_id=ask_feedback_request.project_id,
)

if valid_generation_results := sql_correction_results[
"post_process"
]["valid_generation_results"]:
api_results = [
AskResult(
**{
"sql": valid_generation_result.get("sql"),
"type": "llm",
}
)
for valid_generation_result in valid_generation_results
][:1]
elif failed_dry_run_results := sql_correction_results[
"post_process"
]["invalid_generation_results"]:
error_message = failed_dry_run_results[0]["error"]
else:
error_message = failed_dry_run_results[0]["error"]

if api_results:
if not self._is_stopped(query_id, self._ask_feedback_results):
self._ask_feedback_results[query_id] = AskFeedbackResultResponse(
status="finished",
response=api_results,
)
results["ask_feedback_result"] = api_results
else:
logger.exception("ask feedback pipeline - NO_RELEVANT_SQL")
if not self._is_stopped(query_id, self._ask_feedback_results):
self._ask_feedback_results[query_id] = AskFeedbackResultResponse(
status="failed",
error=AskError(
code="NO_RELEVANT_SQL",
message=error_message or "No relevant SQL",
),
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
results["metadata"]["error_message"] = error_message

return results

except Exception as e:
logger.exception(f"ask feedback pipeline - OTHERS: {e}")

Expand Down

0 comments on commit 3166c53

Please sign in to comment.