diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index 7ba14a5fc..b05c51a5a 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -12,6 +12,7 @@ from src.core.provider import LLMProvider from src.pipelines.generation.utils.sql import ( SQL_GENERATION_MODEL_KWARGS, + TEXT_TO_SQL_RULES, SQLGenPostProcessor, construct_instructions, ) @@ -20,10 +21,33 @@ logger = logging.getLogger("wren-ai-service") -sql_regeneration_system_prompt = """ +sql_regeneration_system_prompt = f""" +### TASK ### +You are a great ANSI SQL expert. Now you are given a SQL generation reasoning and an original SQL query, +please carefully review the reasoning, and then generate a new SQL query that matches the reasoning. +While generating the new SQL query, you should use the original SQL query as a reference. + +{TEXT_TO_SQL_RULES} + +### FINAL ANSWER FORMAT ### +The final answer must be an ANSI SQL query in JSON format: + +{{ + "sql": +}} """ sql_regeneration_user_prompt_template = """ +{% if instructions %} +### INSTRUCTIONS ### +{{ instructions }} +{% endif %} + +### QUESTION ### +SQL generation reasoning: {{ sql_generation_reasoning }} +Original SQL query: {{ sql }} + +Let's think step by step. 
""" @@ -51,7 +75,7 @@ def prompt( @observe(as_type="generation", capture_input=False) -async def generate_sql( +async def regenerate_sql( prompt: dict, generator: Any, ) -> dict: @@ -60,13 +84,13 @@ async def generate_sql( @observe(capture_input=False) async def post_process( - generate_sql: dict, + regenerate_sql: dict, post_processor: SQLGenPostProcessor, engine_timeout: float, project_id: str | None = None, ) -> dict: return await post_processor.run( - generate_sql.get("replies"), + regenerate_sql.get("replies"), timeout=engine_timeout, project_id=project_id, ) diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index dee4334bd..3e8ebff8b 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -351,14 +351,14 @@ async def _task(sql: str): """ sql_generation_system_prompt = f""" -You are a helpful assistant that converts natural language queries into SQL queries. +You are a helpful assistant that converts natural language queries into ANSI SQL queries. Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step. {TEXT_TO_SQL_RULES} ### FINAL ANSWER FORMAT ### -The final answer must be a SQL query in JSON format: +The final answer must be an ANSI SQL query in JSON format: {{ "sql":