diff --git a/wren-ai-service/Justfile b/wren-ai-service/Justfile index 1fb8fbce9..c0e8683aa 100644 --- a/wren-ai-service/Justfile +++ b/wren-ai-service/Justfile @@ -32,8 +32,8 @@ start: use-wren-ui-as-engine curate_eval_data: poetry run streamlit run eval/data_curation/app.py -prep: - poetry run python -m eval.preparation +prep dataset='spider1.0': + poetry run python -m eval.preparation --dataset {{dataset}} predict dataset pipeline='ask': @poetry run python -u eval/prediction.py --file {{dataset}} --pipeline {{pipeline}} diff --git a/wren-ai-service/eval/__init__.py b/wren-ai-service/eval/__init__.py index 1326364f9..0536e7bee 100644 --- a/wren-ai-service/eval/__init__.py +++ b/wren-ai-service/eval/__init__.py @@ -11,6 +11,7 @@ class EvalSettings(Settings): config_path: str = "eval/config.yaml" openai_api_key: SecretStr = Field(alias="LLM_OPENAI_API_KEY") allow_sql_samples: bool = True + db_path_for_duckdb: str = "" # BigQuery bigquery_project_id: str = Field(default="") diff --git a/wren-ai-service/eval/data_curation/app.py b/wren-ai-service/eval/data_curation/app.py index 56ab532c2..615442abf 100644 --- a/wren-ai-service/eval/data_curation/app.py +++ b/wren-ai-service/eval/data_curation/app.py @@ -13,14 +13,6 @@ from streamlit_tags import st_tags sys.path.append(f"{Path().parent.resolve()}") -from eval import EvalSettings -from eval.utils import ( - get_documents_given_contexts, - get_eval_dataset_in_toml_string, - get_openai_client, - prepare_duckdb_init_sql, - prepare_duckdb_session_sql, -) from utils import ( DATA_SOURCES, WREN_ENGINE_ENDPOINT, @@ -32,6 +24,15 @@ prettify_sql, ) +from eval import EvalSettings +from eval.utils import ( + get_documents_given_contexts, + get_eval_dataset_in_toml_string, + get_openai_client, + prepare_duckdb_init_sql, + prepare_duckdb_session_sql, +) + st.set_page_config(layout="wide") st.title("WrenAI Data Curation App") @@ -66,9 +67,9 @@ def on_change_upload_eval_dataset(): doc = tomlkit.parse(st.session_state.uploaded_eval_file.getvalue().decode("utf-8")) - assert doc["mdl"] == st.session_state["mdl_json"], ( - "The model in the uploaded dataset is different from the deployed model" - ) + assert ( + doc["mdl"] == st.session_state["mdl_json"] + ), "The model in the uploaded dataset is different from the deployed model" st.session_state["candidate_dataset"] = doc["eval_dataset"] @@ -116,7 +117,9 @@ def on_click_setup_uploaded_file(): elif data_source == "duckdb": prepare_duckdb_session_sql(WREN_ENGINE_ENDPOINT) prepare_duckdb_init_sql( - WREN_ENGINE_ENDPOINT, st.session_state["mdl_json"]["catalog"] + WREN_ENGINE_ENDPOINT, + st.session_state["mdl_json"]["catalog"], + "etc/spider1.0/database", ) else: st.session_state["data_source"] = None diff --git a/wren-ai-service/eval/pipelines.py b/wren-ai-service/eval/pipelines.py index dce385f6d..4c1920486 100644 --- a/wren-ai-service/eval/pipelines.py +++ b/wren-ai-service/eval/pipelines.py @@ -115,7 +115,8 @@ async def wrapper(batch: list): return [prediction for batch in batches for prediction in batch] @abstractmethod - def _process(self, prediction: dict, **_) -> dict: ... + def _process(self, prediction: dict, **_) -> dict: + ... 
async def _flat(self, prediction: dict, **_) -> dict: """ @@ -247,7 +248,9 @@ def __init__( ) self._allow_sql_samples = settings.allow_sql_samples - self._engine_info = engine_config(mdl, pipe_components) + self._engine_info = engine_config( + mdl, pipe_components, settings.db_path_for_duckdb + ) async def _flat(self, prediction: dict, actual: str) -> dict: prediction["actual_output"] = actual @@ -338,7 +341,9 @@ def __init__( ) self._allow_sql_samples = settings.allow_sql_samples - self._engine_info = engine_config(mdl, pipe_components) + self._engine_info = engine_config( + mdl, pipe_components, settings.db_path_for_duckdb + ) async def _flat(self, prediction: dict, actual: str) -> dict: prediction["actual_output"] = actual diff --git a/wren-ai-service/eval/prediction.py b/wren-ai-service/eval/prediction.py index 442b9dbf3..6146d0419 100644 --- a/wren-ai-service/eval/prediction.py +++ b/wren-ai-service/eval/prediction.py @@ -109,6 +109,13 @@ def parse_args() -> Tuple[str, str]: _mdl = base64.b64encode(orjson.dumps(dataset["mdl"])).decode("utf-8") if "spider_" in path: settings.datasource = "duckdb" + settings.db_path_for_duckdb = "etc/spider1.0/database" + replace_wren_engine_env_variables( + "wren_engine", {"manifest": _mdl}, settings.config_path + ) + elif "bird_" in path: + settings.datasource = "duckdb" + settings.db_path_for_duckdb = "etc/bird/minidev/MINIDEV/dev_databases" replace_wren_engine_env_variables( "wren_engine", {"manifest": _mdl}, settings.config_path ) diff --git a/wren-ai-service/eval/preparation.py b/wren-ai-service/eval/preparation.py index feb5231d1..b07b3078c 100644 --- a/wren-ai-service/eval/preparation.py +++ b/wren-ai-service/eval/preparation.py @@ -1,15 +1,18 @@ """ -This file aims to prepare spider 1.0 eval dataset for text-to-sql eval purpose +This file aims to prepare spider 1.0 or bird eval dataset for text-to-sql eval purpose """ +import argparse import asyncio import os import zipfile from collections import defaultdict from itertools import zip_longest from pathlib import Path +from urllib.request import urlretrieve import gdown import orjson +import pandas as pd from eval.utils import ( get_contexts_from_sql, @@ -21,6 +24,7 @@ ) SPIDER_DESTINATION_PATH = Path("./tools/dev/etc/spider1.0") +BIRD_DESTINATION_PATH = Path("./tools/dev/etc/bird") WREN_ENGINE_API_URL = "http://localhost:8080" EVAL_DATASET_DESTINATION_PATH = Path("./eval/dataset") @@ -57,35 +61,59 @@ def _download_and_extract( ) +def download_bird_data(destination_path: Path): + def _download_and_extract(destination_path: Path, path: Path, file_name: str): + if not (destination_path / path).exists(): + if Path(file_name).exists(): + os.remove(file_name) + + url = "https://bird-bench.oss-cn-beijing.aliyuncs.com/minidev.zip" + + print(f"Downloading {file_name} from {url}...") + urlretrieve(url, file_name) + + with zipfile.ZipFile(file_name, "r") as zip_ref: + zip_ref.extractall(destination_path) + + os.remove(file_name) + + _download_and_extract( + destination_path, + "minidev", + "minidev.zip", + ) + + def get_database_names(path: Path): return [folder.name for folder in path.iterdir() if folder.is_dir()] -def build_mdl_by_db(destination_path: Path): - def _get_tables_by_db(path: Path, key: str): - with open(path, "rb") as f: - json_data = orjson.loads(f.read()) +def get_tables_by_db(path: Path, key: str): + with open(path, "rb") as f: + json_data = orjson.loads(f.read()) - return {item[key]: item for item in json_data} - - def _merge_column_info(column_names_original, column_types): - 
merged_info = [] - for (table_index, column_name), column_type in zip( - column_names_original, column_types - ): - merged_info.append( - { - "table_index": table_index, - "column_name": column_name, - "column_type": column_type, - } - ) - return merged_info + return {item[key]: item for item in json_data} + + +def build_mdl_models(database, tables_info, database_info={}): + def _build_mdl_columns(tables_info, table_index, table_info=None): + def _merge_column_info(column_names_original, column_types): + merged_info = [] + for (table_index, column_name), column_type in zip( + column_names_original, column_types + ): + merged_info.append( + { + "table_index": table_index, + "column_name": column_name, + "column_type": column_type, + } + ) + return merged_info - def _get_columns_by_table_index(columns, table_index): - return list(filter(lambda col: col["table_index"] == table_index, columns)) + def _get_columns_by_table_index(columns, table_index): + return list(filter(lambda col: col["table_index"] == table_index, columns)) - def _build_mdl_columns(tables_info, table_index): _columns = _get_columns_by_table_index( _merge_column_info( tables_info["column_names_original"], tables_info["column_types"] @@ -93,85 +121,122 @@ def _build_mdl_columns(tables_info, table_index): table_index, ) + columns_info = {} + if table_info: + for column_info in table_info: + original_col_key = next( + key for key in column_info.keys() if "original_column_name" in key + ) + if value_description := column_info.get("value_description", ""): + columns_info[column_info[original_col_key]] = ( + column_info.get("column_description", "") + + ", " + + value_description + ).strip() + else: + columns_info[column_info[original_col_key]] = column_info.get( + "column_description", "" + ).strip() + return [ { "name": column["column_name"], "type": column["column_type"], "notNull": False, - "properties": {}, + "properties": { + "description": columns_info.get(column["column_name"], ""), + } + if columns_info and columns_info.get(column["column_name"], "") + else {}, } for column in _columns ] - def _build_mdl_models(database, tables_info): - return [ - { - "name": table, - "properties": {}, - "tableReference": { - "catalog": database, - "schema": "main", - "table": table, - }, - "primaryKey": tables_info["column_names_original"][ - primary_key_column_index - ][-1] - if primary_key_column_index - else "", - "columns": _build_mdl_columns(tables_info, i), - } - for i, (table, primary_key_column_index) in enumerate( - zip_longest( - tables_info["table_names_original"], tables_info["primary_keys"] - ) + return [ + { + "name": table, + "properties": {}, + "tableReference": { + "catalog": database, + "schema": "main", + "table": table, + }, + "primaryKey": tables_info["column_names_original"][ + primary_key_column_index + ][-1] + if primary_key_column_index + else "", + "columns": _build_mdl_columns( + tables_info, i, database_info.get(table, None) + ), + } + for i, (table, primary_key_column_index) in enumerate( + zip_longest( + tables_info["table_names_original"], + filter( + lambda x: isinstance(x, int), tables_info["primary_keys"] + ), # filter out composite primary keys as of now ) + ) + ] + + +def build_mdl_relationships(tables_info): + relationships = [] + for first, second in tables_info["foreign_keys"]: + first_table_index, first_column_name = tables_info["column_names_original"][ + first ] + first_foreign_key_table = tables_info["table_names_original"][first_table_index] + + second_table_index, second_column_name = 
tables_info["column_names_original"][ + second + ] + second_foreign_key_table = tables_info["table_names_original"][ + second_table_index + ] + + relationships.append( + { + "name": f"{first_foreign_key_table}_{first_column_name}_{second_foreign_key_table}_{second_column_name}", + "models": [first_foreign_key_table, second_foreign_key_table], + "joinType": "MANY_TO_MANY", + "condition": f"{first_foreign_key_table}.{first_column_name} = {second_foreign_key_table}.{second_column_name}", + } + ) + + return relationships - def _build_mdl_relationships(tables_info): - relationships = [] - for first, second in tables_info["foreign_keys"]: - first_table_index, first_column_name = tables_info["column_names_original"][ - first - ] - first_foreign_key_table = tables_info["table_names_original"][ - first_table_index - ] - - second_table_index, second_column_name = tables_info[ - "column_names_original" - ][second] - second_foreign_key_table = tables_info["table_names_original"][ - second_table_index - ] - - relationships.append( - { - "name": f"{first_foreign_key_table}_{first_column_name}_{second_foreign_key_table}_{second_column_name}", - "models": [first_foreign_key_table, second_foreign_key_table], - "joinType": "MANY_TO_MANY", - "condition": f"{first_foreign_key_table}.{first_column_name} = {second_foreign_key_table}.{second_column_name}", - } - ) - return relationships +def get_ground_truths_by_db(path: Path, key: str): + with open(path, "rb") as f: + json_data = orjson.loads(f.read()) + results = defaultdict(list) + for item in json_data: + results[item[key]].append(item) + + return results + + +def build_mdl_by_db_using_spider(destination_path: Path): # get all database names in the spider testsuite - databases = get_database_names(destination_path / "database") + database_names = get_database_names(destination_path / "database") # read tables.json and transform it to be a dictionary with database name as key - tables_by_db = _get_tables_by_db( + tables_by_db = get_tables_by_db( destination_path / "spider_data/tables.json", "db_id" ) # build mdl for each database by checking the test_tables.json in spider_data mdl_by_db = {} - for database in databases: + for database in database_names: if tables_info := tables_by_db.get(database): mdl_by_db[database] = { "catalog": database, "schema": "main", - "models": _build_mdl_models(database, tables_info), - "relationships": _build_mdl_relationships(tables_info), + "models": build_mdl_models(database, tables_info), + "relationships": build_mdl_relationships(tables_info), "views": [], "metrics": [], } @@ -179,7 +244,7 @@ def _build_mdl_relationships(tables_info): return mdl_by_db -def build_question_sql_pairs_by_db(destination_path: Path): +def build_question_sql_pairs_by_db_using_spider(destination_path: Path): def _get_ground_truths_by_db(path: Path, key: str): with open(path, "rb") as f: json_data = orjson.loads(f.read()) @@ -191,15 +256,15 @@ def _get_ground_truths_by_db(path: Path, key: str): return results # get all database names in the spider testsuite - databases = get_database_names(destination_path / "database") + database_names = get_database_names(destination_path / "database") # get dev.json and transform it to be a dictionary with database name as key - ground_truths_by_db = _get_ground_truths_by_db( + ground_truths_by_db = get_ground_truths_by_db( destination_path / "spider_data/dev.json", "db_id" ) question_sql_pairs_by_db = defaultdict(list) - for database in databases: + for database in database_names: if ground_truths_info := 
ground_truths_by_db.get(database): for ground_truth in ground_truths_info: question_sql_pairs_by_db[database].append( @@ -212,6 +277,84 @@ def _get_ground_truths_by_db(path: Path, key: str): return question_sql_pairs_by_db +def build_mdl_by_db_using_bird(destination_path: Path): + def _get_database_infos(path: Path): + database_infos = {} + for folder in path.iterdir(): + if folder.is_dir(): + path_to_database_description = ( + path / folder.name / "database_description" + ) + if ( + path_to_database_description in folder.iterdir() + and path_to_database_description.is_dir() + ): + database_infos[folder.name] = {} + for file in path_to_database_description.iterdir(): + if file.is_file() and file.suffix == ".csv": + df = pd.read_csv( + file, encoding="ISO-8859-1", keep_default_na=False + ) + database_infos[folder.name][file.stem] = df.to_dict( + orient="records" + ) + + return database_infos + + database_names = get_database_names( + destination_path / "minidev/MINIDEV/dev_databases" + ) + database_infos = _get_database_infos( + destination_path / "minidev/MINIDEV/dev_databases" + ) + tables_by_db = get_tables_by_db( + destination_path / "minidev/MINIDEV/dev_tables.json", "db_id" + ) + + # build mdl for each database by checking the test_tables.json in spider_data + mdl_by_db = {} + for database in database_names: + if tables_info := tables_by_db.get(database): + mdl_by_db[database] = { + "catalog": database, + "schema": "main", + "models": build_mdl_models( + database, tables_info, database_infos.get(database, {}) + ), + "relationships": build_mdl_relationships(tables_info), + "views": [], + "metrics": [], + } + + return mdl_by_db + + +def build_question_sql_pairs_by_db_using_bird(destination_path: Path): + database_names = get_database_names( + destination_path / "minidev/MINIDEV/dev_databases" + ) + + ground_truths_by_db = get_ground_truths_by_db( + destination_path / "minidev/MINIDEV/mini_dev_sqlite.json", "db_id" + ) + + question_sql_pairs_by_db = defaultdict(list) + for database in database_names: + if ground_truths_info := ground_truths_by_db.get(database): + for ground_truth in ground_truths_info: + question_sql_pairs_by_db[database].append( + { + "question": ground_truth["question"], + "sql": ground_truth["SQL"], + "question_id": ground_truth["question_id"], + "evidence": ground_truth["evidence"], + "difficulty": ground_truth["difficulty"], + } + ) + + return question_sql_pairs_by_db + + def get_mdls_and_question_sql_pairs_by_common_db(mdl_by_db, question_sql_pairs_by_db): common_dbs = set(mdl_by_db.keys()) & set(question_sql_pairs_by_db.keys()) @@ -222,25 +365,57 @@ def get_mdls_and_question_sql_pairs_by_common_db(mdl_by_db, question_sql_pairs_b if __name__ == "__main__": - print(f"Downloading Spider 1.0 data if unavailable in {SPIDER_DESTINATION_PATH}...") - download_spider_data(SPIDER_DESTINATION_PATH) + parser = argparse.ArgumentParser( + description="Prepare evaluation dataset for text-to-sql tasks." + ) + parser.add_argument( + "--dataset", + choices=["spider1.0", "bird"], + default="spider1.0", + help="Choose which dataset to prepare (spider1.0 or bird)", + ) + args = parser.parse_args() + + if args.dataset == "spider1.0": + destination_path = SPIDER_DESTINATION_PATH + print( + f"Downloading {args.dataset} data if unavailable in {destination_path}..." + ) + download_spider_data(destination_path) + elif args.dataset == "bird": + destination_path = BIRD_DESTINATION_PATH + print( + f"Downloading {args.dataset} data if unavailable in {destination_path}..." 
+ ) + download_bird_data(destination_path) - print("Building mdl and question sql pairs using Spider 1.0 data...") + print(f"Building mdl and question sql pairs using {args.dataset} data...") # get mdl_by_db and question_sql_pairs_by_db whose dbs are present in both dictionaries - mdl_and_ground_truths_by_db = get_mdls_and_question_sql_pairs_by_common_db( - build_mdl_by_db(SPIDER_DESTINATION_PATH), - build_question_sql_pairs_by_db(SPIDER_DESTINATION_PATH), - ) + if args.dataset == "spider1.0": + mdl_and_ground_truths_by_db = get_mdls_and_question_sql_pairs_by_common_db( + build_mdl_by_db_using_spider(destination_path), + build_question_sql_pairs_by_db_using_spider(destination_path), + ) + elif args.dataset == "bird": + mdl_and_ground_truths_by_db = get_mdls_and_question_sql_pairs_by_common_db( + build_mdl_by_db_using_bird(destination_path), + build_question_sql_pairs_by_db_using_bird(destination_path), + ) print("Creating eval dataset...") # create duckdb connection in wren engine # https://duckdb.org/docs/guides/database_integration/sqlite.html prepare_duckdb_session_sql(WREN_ENGINE_API_URL) + questions_size = 0 + if args.dataset == "spider1.0": + duckdb_init_path = "etc/spider1.0/database" + elif args.dataset == "bird": + duckdb_init_path = "etc/bird/minidev/MINIDEV/dev_databases" for db, values in sorted(mdl_and_ground_truths_by_db.items()): candidate_eval_dataset = [] print(f"Database: {db}") - prepare_duckdb_init_sql(WREN_ENGINE_API_URL, db) + prepare_duckdb_init_sql(WREN_ENGINE_API_URL, db, duckdb_init_path) for i, ground_truth in enumerate(values["ground_truth"]): context = asyncio.run( @@ -278,9 +453,12 @@ def get_mdls_and_question_sql_pairs_by_common_db(mdl_by_db, question_sql_pairs_b # save eval dataset if candidate_eval_dataset: - with open( - f"{EVAL_DATASET_DESTINATION_PATH}/spider_{db}_eval_dataset.toml", "w" - ) as f: + if args.dataset == "spider1.0": + file_name = f"spider_{db}_eval_dataset.toml" + elif args.dataset == "bird": + file_name = f"bird_{db}_eval_dataset.toml" + + with open(f"{EVAL_DATASET_DESTINATION_PATH}/{file_name}", "w") as f: f.write( get_eval_dataset_in_toml_string( values["mdl"], candidate_eval_dataset @@ -289,4 +467,7 @@ def get_mdls_and_question_sql_pairs_by_common_db(mdl_by_db, question_sql_pairs_b print( f"Successfully creating eval dataset of database {db}, which has {len(candidate_eval_dataset)} questions" ) + questions_size += len(candidate_eval_dataset) print() + + print(f"Total questions size: {questions_size}") diff --git a/wren-ai-service/eval/utils.py b/wren-ai-service/eval/utils.py index 13189a3f9..eeeec39f1 100644 --- a/wren-ai-service/eval/utils.py +++ b/wren-ai-service/eval/utils.py @@ -164,7 +164,9 @@ async def _get_sql_analysis( ) -> List[dict]: sql = sql.rstrip(";") if sql.endswith(";") else sql quoted_sql, no_error = add_quotes(sql) - assert no_error, f"Error in quoting SQL: {sql}" + if not no_error: + print(f"Error in quoting SQL: {sql}") + quoted_sql = sql manifest_str = base64.b64encode(orjson.dumps(mdl_json)).decode() @@ -211,7 +213,9 @@ def trace_metadata( } -def engine_config(mdl: dict, pipe_components: dict[str, Any] = {}) -> dict: +def engine_config( + mdl: dict, pipe_components: dict[str, Any] = {}, path: str = "" +) -> dict: engine = pipe_components.get("sql_generation", {}).get("engine") if engine is None: @@ -222,7 +226,7 @@ def engine_config(mdl: dict, pipe_components: dict[str, Any] = {}) -> dict: if isinstance(engine, WrenEngine): print("datasource is duckdb") prepare_duckdb_session_sql(engine._endpoint) - 
prepare_duckdb_init_sql(engine._endpoint, mdl["catalog"]) + prepare_duckdb_init_sql(engine._endpoint, mdl["catalog"], path) return { "mdl_json": mdl, "api_endpoint": engine._endpoint, @@ -540,10 +544,8 @@ def prepare_duckdb_session_sql(api_endpoint: str): assert response.status_code == 200, response.text -def prepare_duckdb_init_sql(api_endpoint: str, db: str): - init_sql = ( - f"ATTACH 'etc/spider1.0/database/{db}/{db}.sqlite' AS {db} (TYPE sqlite);" - ) +def prepare_duckdb_init_sql(api_endpoint: str, db: str, path: str): + init_sql = f"ATTACH '{path}/{db}/{db}.sqlite' AS {db} (TYPE sqlite);" response = requests.put( f"{api_endpoint}/v1/data-source/duckdb/settings/init-sql",
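
Usage sketch for the new `path` argument threaded through `engine_config` and `prepare_duckdb_init_sql` above. This mirrors how eval/preparation.py and eval/prediction.py wire a dataset-specific database directory into the DuckDB init SQL; it assumes a wren-engine instance listening at http://localhost:8080 and the BIRD minidev databases already extracted (e.g. via the new `just prep bird` recipe). The database name used below is only an illustrative placeholder.

    # Minimal sketch, assuming a running wren-engine and extracted BIRD minidev data.
    from eval.utils import prepare_duckdb_init_sql, prepare_duckdb_session_sql

    WREN_ENGINE_API_URL = "http://localhost:8080"

    # Register the DuckDB session SQL once per engine instance.
    prepare_duckdb_session_sql(WREN_ENGINE_API_URL)

    # Spider 1.0 keeps its original location ("etc/spider1.0/database");
    # BIRD points at the extracted minidev databases instead.
    prepare_duckdb_init_sql(
        WREN_ENGINE_API_URL,
        "california_schools",  # hypothetical BIRD database name, for illustration only
        "etc/bird/minidev/MINIDEV/dev_databases",
    )

    # Per the updated prepare_duckdb_init_sql, the call above issues init SQL of the form:
    #   ATTACH '<path>/<db>/<db>.sqlite' AS <db> (TYPE sqlite);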