Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Feb 7, 2025
1 parent 9c0cda1 commit c64f9da
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
32 changes: 28 additions & 4 deletions wren-ai-service/src/pipelines/generation/sql_regeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from src.core.provider import LLMProvider
from src.pipelines.generation.utils.sql import (
SQL_GENERATION_MODEL_KWARGS,
TEXT_TO_SQL_RULES,
SQLGenPostProcessor,
construct_instructions,
)
Expand All @@ -20,10 +21,33 @@
logger = logging.getLogger("wren-ai-service")


sql_regeneration_system_prompt = """
sql_regeneration_system_prompt = f"""
### TASK ###
You are a great ANSI SQL expert. Now you are given a SQL generation reasoning and an original SQL query,
please carefully review the reasoning, and then generate a new SQL query that matches the reasoning.
While generating the new SQL query, you should use the original SQL query as a reference.
{TEXT_TO_SQL_RULES}
### FINAL ANSWER FORMAT ###
The final answer must be a ANSI SQL query in JSON format:
{{
"sql": <SQL_QUERY_STRING>
}}
"""

sql_regeneration_user_prompt_template = """
{% if instructions %}
### INSTRUCTIONS ###
{{ instructions }}
{% endif %}
### QUESTION ###
SQL generation reasoning: {{ sql_generation_reasoning }}
Original SQL query: {{ sql }}
Let's think step by step.
"""


Expand Down Expand Up @@ -51,7 +75,7 @@ def prompt(


@observe(as_type="generation", capture_input=False)
async def generate_sql(
async def regenerate_sql(
prompt: dict,
generator: Any,
) -> dict:
Expand All @@ -60,13 +84,13 @@ async def generate_sql(

@observe(capture_input=False)
async def post_process(
generate_sql: dict,
regenerate_sql: dict,
post_processor: SQLGenPostProcessor,
engine_timeout: float,
project_id: str | None = None,
) -> dict:
return await post_processor.run(
generate_sql.get("replies"),
regenerate_sql.get("replies"),
timeout=engine_timeout,
project_id=project_id,
)
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/pipelines/generation/utils/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,14 +351,14 @@ async def _task(sql: str):
"""

sql_generation_system_prompt = f"""
You are a helpful assistant that converts natural language queries into SQL queries.
You are a helpful assistant that converts natural language queries into ANSI SQL queries.
Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step.
{TEXT_TO_SQL_RULES}
### FINAL ANSWER FORMAT ###
The final answer must be a SQL query in JSON format:
The final answer must be a ANSI SQL query in JSON format:
{{
"sql": <SQL_QUERY_STRING>
Expand Down

0 comments on commit c64f9da

Please sign in to comment.