diff --git a/lazyllm/docs/tools.py b/lazyllm/docs/tools.py index 56d4cd15..9bf87c10 100644 --- a/lazyllm/docs/tools.py +++ b/lazyllm/docs/tools.py @@ -1439,62 +1439,7 @@ """\ SqlManager是与数据库进行交互的专用工具。它提供了连接数据库,设置、创建、检查数据表,插入数据,执行查询的方法。 -Arguments: - db_type (str): 目前仅支持"PostgreSQL",后续会增加"MySQL", "MS SQL" - user (str): username - password (str): password - host (str): 主机名或IP - port (int): 端口号 - db_name (str): 数据仓库名 - tables_info_dict (dict): 数据表的描述 - options_str (str): k1=v1&k2=v2形式表示的选项设置 -""", -) - -add_english_doc( - "SqlManager", - """\ -SqlManager is a specialized tool for interacting with databases. -It provides methods for creating tables, executing queries, and performing updates on databases. - -Arguments: - db_type (str): Currently only "PostgreSQL" is supported, with "MySQL" and "MS SQL" to be added later. - user (str): Username for connection - password (str): Password for connection - host (str): Hostname or IP - port (int): Port number - db_name (str): Name of the database - tables_info_dict (dict): Description of the data tables - options_str (str): Options represented in the format k1=v1&k2=v2 -""", -) - -add_example( - "SqlManager", - """\ - >>> from lazyllm.tools import SqlManager - >>> import uuid - >>> # !!!NOTE!!!: COPY class SqlEgsData definition from tests/charge_tests/utils.py then Paste here. - >>> db_filepath = "personal.db" - >>> with open(db_filepath, "w") as _: - pass - >>> sql_manager = SQLiteManger(filepath, SqlEgsData.TEST_TABLES_INFO) - >>> # Altert: If using online database, ask administrator about those value: db_type, username, password, host, port, database - >>> # sql_manager = SqlManager(db_type, username, password, host, port, database, SqlEgsData.TEST_TABLES_INFO) - >>> - >>> for insert_script in SqlEgsData.TEST_INSERT_SCRIPTS: - ... sql_manager.execute_sql_update(insert_script) - >>> str_results = sql_manager.get_query_result_in_json(SqlEgsData.TEST_QUERY_SCRIPTS) - >>> print(str_results) -""", -) - -add_chinese_doc( - "SqlManager.reset_table_info_dict", - """\ -根据描述表结构的字典设置SqlManager所使用的数据表。注意:若表在数据库中不存在将会自动创建,若存在则会校验所有字段的一致性。 -字典格式关键字示例如下。 - +tables_info_dict字典格式关键字示例如下。 字典中有3个关键字为可选项:表及列的comment默认为空, is_primary_key默认为False但至少应有一列为True, nullable默认为True {"tables": [ @@ -1518,16 +1463,26 @@ } ] } + +Arguments: + db_type (str): 目前仅支持"PostgreSQL",后续会增加"MySQL", "MS SQL" + user (str): username + password (str): password + host (str): 主机名或IP + port (int): 端口号 + db_name (str): 数据仓库名 + tables_info_dict (dict): 数据表的描述 + options_str (str): k1=v1&k2=v2形式表示的选项设置 """, ) add_english_doc( - "SqlManager.reset_table_info_dict", + "SqlManager", """\ -Set the data tables used by SqlManager according to the dictionary describing the table structure. -Note that if the table does not exist in the database, it will be automatically created, and if it exists, all field consistencies will be checked. -The dictionary format keyword example is as follows. +SqlManager is a specialized tool for interacting with databases. +It provides methods for creating tables, executing queries, and performing updates on databases. +The dictionary format of tables_info_dict is as follows. There are three optional keywords in the dictionary: "comment" for the table and columns defaults to empty, "is_primary_key" defaults to False, but at least one column should be True, and "nullable" defaults to True. {"tables": @@ -1552,13 +1507,43 @@ } ] } + +Arguments: + db_type (str): Currently only "PostgreSQL" is supported, with "MySQL" and "MS SQL" to be added later. + user (str): Username for connection + password (str): Password for connection + host (str): Hostname or IP + port (int): Port number + db_name (str): Name of the database + tables_info_dict (dict): Description of the data tables + options_str (str): Options represented in the format k1=v1&k2=v2 +""", +) + +add_example( + "SqlManager", + """\ + >>> from lazyllm.tools import SqlManager + >>> import uuid + >>> # !!!NOTE!!!: COPY class SqlEgsData definition from tests/charge_tests/utils.py then Paste here. + >>> db_filepath = "personal.db" + >>> with open(db_filepath, "w") as _: + pass + >>> sql_manager = SQLiteManger(filepath, SqlEgsData.TEST_TABLES_INFO) + >>> # Altert: If using online database, ask administrator about those value: db_type, username, password, host, port, database + >>> # sql_manager = SqlManager(db_type, username, password, host, port, database, SqlEgsData.TEST_TABLES_INFO) + >>> + >>> for insert_script in SqlEgsData.TEST_INSERT_SCRIPTS: + ... sql_manager.execute_sql_update(insert_script) + >>> str_results = sql_manager.get_query_result_in_json(SqlEgsData.TEST_QUERY_SCRIPTS) + >>> print(str_results) """, ) add_chinese_doc( - "SqlManagerBase", + "SqlBase", """\ -SqlManagerBase是与数据库进行交互的专用工具。它提供了连接数据库,设置、创建、检查数据表,插入数据,执行查询的方法。 +SqlBase是与数据库进行交互的专用工具。它提供了连接数据库,设置、创建、检查数据表,插入数据,执行查询的方法。 Arguments: db_type (str): 目前仅支持"PostgreSQL",后续会增加"MySQL", "MS SQL" @@ -1573,9 +1558,9 @@ ) add_english_doc( - "SqlManagerBase", + "SqlBase", """\ -SqlManagerBase is a specialized tool for interacting with databases. +SqlBase is a specialized tool for interacting with databases. It provides methods for creating tables, executing queries, and performing updates on databases. Arguments: @@ -1591,9 +1576,9 @@ ) add_chinese_doc( - "SqlManagerBase.check_connection", + "SqlBase.check_connection", """\ -检查当前SqlManagerBase的连接状态。 +检查当前SqlBase的连接状态。 **Returns:**\n - bool: 连接成功(True), 连接失败(False) @@ -1602,7 +1587,7 @@ ) add_english_doc( - "SqlManagerBase.check_connection", + "SqlBase.check_connection", """\ Check the current connection status of the SqlManagerBase. @@ -1613,33 +1598,19 @@ ) add_chinese_doc( - "SqlManagerBase.execute_to_json", + "SqlBase.execute_to_json", """\ 执行SQL查询并返回JSON格式的结果。 """, ) add_english_doc( - "SqlManagerBase.execute_to_json", + "SqlBase.execute_to_json", """\ Executes a SQL query and returns the result in JSON format. """, ) -add_chinese_doc( - "SqlManagerBase.execute", - """\ -在SQLite数据库上执行SQL插入或更新脚本。 -""", -) - -add_english_doc( - "SqlManagerBase.execute", - """\ -Execute insert or update script. -""", -) - add_chinese_doc( "SqlCall", """\ diff --git a/lazyllm/tools/__init__.py b/lazyllm/tools/__init__.py index 65adb436..3f6621e0 100644 --- a/lazyllm/tools/__init__.py +++ b/lazyllm/tools/__init__.py @@ -10,7 +10,7 @@ ReWOOAgent, ) from .classifier import IntentClassifier -from .sql import SqlManagerBase, SQLiteManger, SqlManager, MongoDBManager, DBResult, DBStatus +from .sql import SqlBase, SqlManager, MongoDBManager, DBResult, DBStatus from .sql_call import SqlCall from .tools.http_tool import HttpTool @@ -29,8 +29,7 @@ "ReWOOAgent", "IntentClassifier", "SentenceSplitter", - "SqlManagerBase", - "SQLiteManger", + "SqlBase", "SqlManager", "MongoDBManager", "DBResult", diff --git a/lazyllm/tools/sql/__init__.py b/lazyllm/tools/sql/__init__.py index 577ab5e9..ceafa955 100644 --- a/lazyllm/tools/sql/__init__.py +++ b/lazyllm/tools/sql/__init__.py @@ -1,5 +1,5 @@ -from .sql_manager import SqlManager, SqlManagerBase, SQLiteManger +from .sql_manager import SqlManager, SqlBase from .mongodb_manager import MongoDBManager from .db_manager import DBManager, DBResult, DBStatus -__all__ = ["DBManager", "SqlManagerBase", "SQLiteManger", "SqlManager", "MongoDBManager", "DBResult", "DBStatus"] +__all__ = ["DBManager", "SqlBase", "SqlManager", "MongoDBManager", "DBResult", "DBStatus"] diff --git a/lazyllm/tools/sql/db_manager.py b/lazyllm/tools/sql/db_manager.py index 6c2294df..ca811e28 100644 --- a/lazyllm/tools/sql/db_manager.py +++ b/lazyllm/tools/sql/db_manager.py @@ -1,8 +1,7 @@ from enum import Enum, unique -from typing import List, Union, overload +from typing import List, Union from pydantic import BaseModel from abc import ABC, abstractmethod -from urllib.parse import quote_plus @unique @@ -16,55 +15,43 @@ class DBResult(BaseModel): detail: str = "Success" result: Union[List, None] = None - class DBManager(ABC): DB_TYPE_SUPPORTED = set(["postgresql", "mysql", "mssql", "sqlite", "mongodb"]) - DB_DRIVER_MAP = {"mysql": "pymysql"} - def __init__( - self, - db_type: str, - user: str, - password: str, - host: str, - port: int, - db_name: str, - options_str: str = "", - ) -> None: - password = quote_plus(password) - self.status = DBStatus.SUCCESS - self.detail = "" + def __init__(self, db_type: str): db_type = db_type.lower() - db_result = self.reset_engine(db_type, user, password, host, port, db_name, options_str) - if db_result.status != DBStatus.SUCCESS: - raise ValueError(db_result.detail) - - @overload - def reset_engine(self, db_type, user, password, host, port, db_name, options_str) -> DBResult: - pass + if db_type not in self.DB_TYPE_SUPPORTED: + raise ValueError(f"{db_type} not supported") + self._db_type = db_type + self._desc = None @abstractmethod - def execute_to_json(self, statement): + def execute_to_json(self, statement) -> str: pass @property - def db_type(self): + def db_type(self) -> str: return self._db_type @property - def desc(self): - return self._desc - - def _is_str_or_nested_dict(self, value): - if isinstance(value, str): - return True - elif isinstance(value, dict): - return all(self._is_str_or_nested_dict(v) for v in value.values()) - return False - - def _validate_desc(self, d): - return isinstance(d, dict) and all(self._is_str_or_nested_dict(v) for v in d.values()) - - def _serialize_uncommon_type(self, obj): + @abstractmethod + def desc(self) -> str: pass + + @staticmethod + def _is_dict_all_str(d): + if not isinstance(d, dict): + return False + for key, value in d.items(): + if not isinstance(key, str): + return False + if isinstance(value, dict): + if not DBManager._is_dict_all_str(value): + return False + elif not isinstance(value, str): + return False + return True + + @staticmethod + def _serialize_uncommon_type(obj): if not isinstance(obj, int, str, float, bool, tuple, list, dict): return str(obj) diff --git a/lazyllm/tools/sql/mongodb_manager.py b/lazyllm/tools/sql/mongodb_manager.py index 2603d798..d7419fff 100644 --- a/lazyllm/tools/sql/mongodb_manager.py +++ b/lazyllm/tools/sql/mongodb_manager.py @@ -1,9 +1,10 @@ import json import pydantic from lazyllm.thirdparty import pymongo +from urllib.parse import quote_plus +from contextlib import contextmanager from .db_manager import DBManager, DBStatus, DBResult - class CollectionDesc(pydantic.BaseModel): summary: str = "" schema_type: dict @@ -11,124 +12,93 @@ class CollectionDesc(pydantic.BaseModel): class MongoDBManager(DBManager): - def __init__(self, user, password, host, port, db_name, collection_name, options_str=""): - result = self.reset_client(user, password, host, port, db_name, collection_name, options_str) - self.status, self.detail = result.status, result.detail - if self.status != DBStatus.SUCCESS: - raise ValueError(self.detail) - - def reset_client(self, user, password, host, port, db_name, collection_name, options_str="") -> DBResult: - self._db_type = "mongodb" - self.status = DBStatus.SUCCESS - self.detail = "" - conn_url = f"{self._db_type}://{user}:{password}@{host}:{port}/" - self._conn_url = conn_url - self._db_name = db_name - self._collection_name = collection_name - if options_str: - self._extra_fields = { - key: value for key_value in options_str.split("&") for key, value in (key_value.split("="),) - } - else: - self._extra_fields = {} - self._client = pymongo.MongoClient(self._conn_url) - result = self.check_connection() - self._collection = self._client[self._db_name][self._collection_name] - self._desc = {} - if result.status != DBStatus.SUCCESS: - return result + MAX_TIMEOUT_MS = 5000 + + def __init__(self, user: str, password: str, host: str, port: int, db_name: str, collection_name: str, + options_str="", collection_desc_dict: dict = None): + super().__init__(db_type="mongodb") + self.user = user + self.password = password + self.host = host + self.port = port + self.db_name = db_name + self.collection_name = collection_name + self.options_str = options_str + self._collection = None + self.collection_desc_dict = None + self._conn_url = self._gen_conn_url() + + def _gen_conn_url(self) -> str: + password = quote_plus(self.password) + conn_url = (f"{self._db_type}://{self.user}:{password}@{self.host}:{self.port}/" + f"{('?' + self.options_str) if self.options_str else ''}") + return conn_url + + @contextmanager + def get_client(self): """ - if db_name not in self.client.list_database_names(): - return DBResult(status=DBStatus.FAIL, detail=f"Database {db_name} not found") - if collection_name not in self.client[db_name].list_collection_names(): - return DBResult(status=DBStatus.FAIL, detail=f"Collection {collection_name} not found") + Get the client object with the context manager. + Use client to manage database. + + Usage: + + >>> with mongodb_manager.get_client() as client: + >>> all_dbs = client.list_database_names() """ - return DBResult() + client = pymongo.MongoClient(self._conn_url, serverSelectionTimeoutMS=self.MAX_TIMEOUT_MS) + try: + yield client + finally: + client.close() + + @property + def desc(self): + if self._desc is None: + self.set_desc(schema_desc_dict=self.collection_desc_dict) + return self._desc + + def set_desc(self, schema_desc_dict: dict): + self.collection_desc_dict = schema_desc_dict + if schema_desc_dict is None: + with self.get_client() as client: + egs_one = client[self.db_name][self.collection_name].find_one() + if egs_one is not None: + self._desc = "Collection Example:\n" + self._desc += json.dumps(egs_one, ensure_ascii=False, indent=4) + else: + self._desc = "" + try: + collection_desc = CollectionDesc.model_validate(schema_desc_dict) + except pydantic.ValidationError as e: + raise ValueError(f"Validate input schema_desc_dict failed: {str(e)}") + if not self._is_dict_all_str(collection_desc.schema_type): + raise ValueError("schema_type shouble be str or nested str dict") + if not self._is_dict_all_str(collection_desc.schema_desc): + raise ValueError("schema_desc shouble be str or nested str dict") + if collection_desc.summary: + self._desc += f"Collection summary: {collection_desc.summary}\n" + self._desc += "Collection schema:\n" + self._desc += json.dumps(collection_desc.schema_type, ensure_ascii=False, indent=4) + self._desc += "Collection schema description:\n" + self._desc += json.dumps(collection_desc.schema_type, ensure_ascii=False, indent=4) def check_connection(self) -> DBResult: try: # check connection status - _ = self._client.server_info() + with pymongo.MongoClient(self._conn_url, serverSelectionTimeoutMS=self.MAX_TIMEOUT_MS) as client: + _ = client.server_info() return DBResult() except Exception as e: return DBResult(status=DBStatus.FAIL, detail=str(e)) - def get_all_collections(self): - return DBResult(result=self._client[self._db_name].list_collection_names()) - - def drop_database(self) -> DBResult: - if self.status != DBStatus.SUCCESS: - return DBResult(status=self.status, detail=self.detail, result=None) - self._client.drop_database(self._db_name) - return DBResult() - - def drop_collection(self, collection_name) -> DBResult: - db = self._client[self._db_name] - db[collection_name].drop() - return DBResult() - - def insert(self, statement): - if isinstance(statement, dict): - self._collection.insert_one(statement) - elif isinstance(statement, list): - self._collection.insert_many(statement) - else: - return DBResult(status=DBStatus.FAIL, detail=f"statement type {type(statement)} not supported", result=None) - return DBResult() - - def update(self, filter: dict, value: dict, is_many: bool = True): - if is_many: - self._collection.update_many(filter, value) - else: - self._collection.update_one(filter, value) - return DBResult() - - def delete(self, filter: dict, is_many: bool = True): - if is_many: - self._collection.delete_many(filter) - else: - self._collection.delete_one(filter) - - def select(self, query, projection: dict[str, bool] = None, limit: int = None): - if limit is None: - result = self._collection.find(query, projection) - else: - result = self._collection.find(query, projection).limit(limit) - return DBResult(result=list(result)) - - def execute(self, statement): + def execute_to_json(self, statement) -> str: + str_result = "" try: pipeline_list = json.loads(statement) - result = self._collection.aggregate(pipeline_list) - return DBResult(result=list(result)) + with self.get_client() as client: + collection = client[self.db_name][self.collection_name] + result = list(collection.aggregate(pipeline_list)) + str_result = json.dumps(result, ensure_ascii=False, default=self._serialize_uncommon_type) except Exception as e: - return DBResult(status=DBStatus.FAIL, detail=str(e)) - - def execute_to_json(self, statement) -> str: - dbresult = self.execute(statement) - if dbresult.status != DBStatus.SUCCESS: - self.status, self.detail = dbresult.status, dbresult.detail - return "" - str_result = json.dumps(dbresult.result, ensure_ascii=False, default=self._serialize_uncommon_type) + str_result = f"MongoDB ERROR: {str(e)}" return str_result - - @property - def desc(self): - return self._desc - - def set_desc(self, schema_and_desc: dict) -> DBResult: - self._desc = "" - try: - collection_desc = CollectionDesc.model_validate(schema_and_desc) - except pydantic.ValidationError as e: - return DBResult(status=DBStatus.FAIL, detail=str(e)) - if not self._validate_desc(collection_desc.schema_type) or not self._validate_desc(collection_desc.schema_desc): - err_msg = "key and value in desc shoule be str or nested str dict" - return DBResult(status=DBStatus.FAIL, detail=err_msg) - if collection_desc.summary: - self._desc += f"Collection summary: {collection_desc.summary}\n" - self._desc += "Collection schema:\n" - self._desc += json.dumps(collection_desc.schema_type, ensure_ascii=False, indent=4) - self._desc += "Collection schema description:\n" - self._desc += json.dumps(collection_desc.schema_type, ensure_ascii=False, indent=4) - return DBResult() diff --git a/lazyllm/tools/sql/sql_manager.py b/lazyllm/tools/sql/sql_manager.py index ad3c4806..699da834 100644 --- a/lazyllm/tools/sql/sql_manager.py +++ b/lazyllm/tools/sql/sql_manager.py @@ -1,64 +1,55 @@ import json -from typing import Union +from typing import Union, List, Type import sqlalchemy from sqlalchemy.exc import SQLAlchemyError, OperationalError from sqlalchemy.orm import DeclarativeBase, DeclarativeMeta, sessionmaker +from sqlalchemy.ext.automap import automap_base import pydantic from contextlib import contextmanager from .db_manager import DBManager, DBStatus, DBResult +from urllib.parse import quote_plus +import re class TableBase(DeclarativeBase): pass -class SqlManagerBase(DBManager): - - def __init__(self, db_type, user, password, host, port, db_name, options_str="", set_default_des=True): - self._set_default_desc = set_default_des - super().__init__(db_type, user, password, host, port, db_name, options_str) - - def reset_engine( - self, - db_type: str, - user: str, - password: str, - host: str, - port: int, - db_name: str, - options_str: str = "", - ): - self._db_type = db_type - if db_type not in self.DB_TYPE_SUPPORTED: - return DBResult(status=DBStatus.FAIL, detail=f"{db_type} not supported") - if db_type in self.DB_DRIVER_MAP: - conn_url = f"{db_type}+{self.DB_DRIVER_MAP[db_type]}://{user}:{password}@{host}:{port}/{db_name}" +class SqlBase(DBManager): + DB_DRIVER_MAP = {"mysql": "pymysql"} + + def __init__(self, db_type: str, user: str, password: str, host: str, port: int, db_name: str, options_str=""): + super().__init__(db_type) + self.user = user + self.password = password + self.host = host + self.port = port + self.db_name = db_name + self.options_str = options_str + self.tables_desc_dict = {} + self._engine = None + self._llm_visible_tables = None + self._metadata = sqlalchemy.MetaData() + + def _gen_conn_url(self) -> str: + if self._db_type == "sqlite": + conn_url = f"sqlite:///{self.db_name}{('?' + self.options_str) if self.options_str else ''}" else: - conn_url = f"{db_type}://{user}:{password}@{host}:{port}/{db_name}" - self._conn_url = conn_url + driver = self.DB_DRIVER_MAP.get(self._db_type, "") + password = quote_plus(self.password) + conn_url = (f"{self._db_type}{('+' + driver) if driver else ''}://{self.user}:{password}@{self.host}" + f":{self.port}/{self.db_name}{('?' + self.options_str) if self.options_str else ''}") + return conn_url - self._engine = sqlalchemy.create_engine(self._conn_url) - self._Session = sessionmaker(bind=self._engine) - self._desc = "" - extra_fields = {} - if options_str: - extra_fields = { - key: value for key_value in options_str.split("&") for key, value in (key_value.split("="),) - } - self._extra_fields = extra_fields - db_result = self.check_connection() - if db_result.status != DBStatus.SUCCESS: - return db_result - db_result = self.get_all_tables() - if db_result.status != DBStatus.SUCCESS: - return db_result - self._visible_tables = db_result.result - if self._set_default_desc: - return self.set_desc() - return db_result + @property + def engine(self): + if self._engine is None: + self._engine = sqlalchemy.create_engine(self._gen_conn_url()) + return self._engine @contextmanager def get_session(self): - session = self._Session() + _Session = sessionmaker(bind=self.engine) + session = _Session() try: yield session session.commit() @@ -68,227 +59,136 @@ def get_session(self): finally: session.close() - def get_all_tables(self) -> DBResult: - inspector = sqlalchemy.inspect(self._engine) - table_names = inspector.get_table_names(schema=self._extra_fields.get("schema", None)) - if self.status != DBStatus.SUCCESS: - return DBResult(status=self.status, detail=self.detail, result=None) - return DBResult(result=table_names) - def check_connection(self) -> DBResult: try: - with self._engine.connect() as _: + with self.engine.connect() as _: return DBResult() except SQLAlchemyError as e: return DBResult(status=DBStatus.FAIL, detail=str(e)) @property - def visible_tables(self): - return self._visible_tables - - def set_visible_tables(self, tables: list[str]) -> DBResult: - db_result = self.get_all_tables() - if db_result.status != DBStatus.SUCCESS: - return db_result - all_tables_in_db = set(db_result.result) - visible_tables = [] - failed_tables = [] - for ele in tables: - if ele in all_tables_in_db: - visible_tables.append(ele) - else: - failed_tables.append(ele) - if len(tables) != len(visible_tables): - db_result = DBResult(status=DBStatus.FAIL, detail=f"{failed_tables} missing in database") - else: - db_result = DBResult() - self._visible_tables = visible_tables - return db_result - - def _get_table_columns(self, table_name: str): - inspector = sqlalchemy.inspect(self._engine) - columns = inspector.get_columns(table_name, schema=self._extra_fields.get("schema", None)) - return columns + def desc(self) -> str: + if self._desc is None: + self.set_desc(tables_desc_dict={}) + return self._desc - def set_desc(self, tables_desc: dict = {}) -> DBResult: + def set_desc(self, tables_desc_dict: dict = {}): self._desc = "" - if not isinstance(tables_desc, dict): - return DBResult(status=DBStatus.FAIL, detail=f"desc type {type(tables_desc)} not supported") - if len(tables_desc) == 0: - return DBResult(status=DBStatus.FAIL, detail="Empty desc") - if len(self.visible_tables) == 0: - return DBResult() + if not isinstance(tables_desc_dict, dict): + raise ValueError(f"desc type {type(tables_desc_dict)} not supported") + self.tables_desc_dict = tables_desc_dict + if len(self.llm_visible_tables) == 0: + return + # Generate desc according to table schema and comment self._desc = "The tables description is as follows\n```\n" - for table_name in self.visible_tables: + for table_name in self.llm_visible_tables: self._desc += f"Table {table_name}\n(\n" - table_columns = self._get_table_columns(table_name) + table_columns = self.get_table_orm_class(table_name).columns for i, column in enumerate(table_columns): self._desc += f" {column['name']} {column['type']}" if i != len(table_columns) - 1: self._desc += "," self._desc += "\n" self._desc += ");\n" - if table_name in tables_desc: - self._desc += tables_desc[table_name] + "\n\n" + if table_name in tables_desc_dict: + self._desc += tables_desc_dict[table_name] + "\n\n" self._desc += "```\n" - return DBResult() @property - def desc(self) -> str: - return self._desc - - def execute(self, statement) -> DBResult: - if isinstance(statement, str): - statement = sqlalchemy.text(statement) - if isinstance( - statement, - (sqlalchemy.TextClause, sqlalchemy.Select, sqlalchemy.Insert, sqlalchemy.Update, sqlalchemy.Delete), - ): - status = DBStatus.SUCCESS - detail = "" - result = None - try: - with self._engine.connect() as conn: - cursor_result = conn.execute(statement) - conn.commit() - if cursor_result.returns_rows: - columns = list(cursor_result.keys()) - result = [dict(zip(columns, row)) for row in cursor_result] - except OperationalError as e: - status = DBStatus.FAIL - detail = f"ERROR: {str(e)}" - finally: - if "conn" in locals(): - conn.close() - return DBResult(status=status, detail=detail, result=result) - else: - return DBResult(status=DBStatus.FAIL, detail="statement type not supported") - - def execute_to_json(self, statement) -> str: - dbresult = self.execute(statement) - if dbresult.status != DBStatus.SUCCESS: - self.status, self.detail = dbresult.status, dbresult.detail - return "" - if dbresult.result is None: - return "" - str_result = json.dumps(dbresult.result, ensure_ascii=False, default=self._serialize_uncommon_type) + def llm_visible_tables(self): + if self._llm_visible_tables is None: + self._llm_visible_tables = self.get_all_tables() + return self._llm_visible_tables + + @llm_visible_tables.setter + def llm_visible_tables(self, visible_tables: list): + all_tables = set(self.get_all_tables()) + for ele in visible_tables: + if ele not in all_tables: + raise ValueError(f"Table {ele} not found in database") + self._llm_visible_tables = visible_tables + self.set_desc(self.tables_desc_dict) + + def _refresh_metadata(self, only=None): + # refresh metadata in case of deleting/creating table in other session + self._metadata.clear() + self._metadata.reflect(bind=self.engine, only=only) + + def get_all_tables(self) -> list: + self._refresh_metadata() + return list(self._metadata.tables.keys()) + + def get_table_orm_class(self, table_name): + self._refresh_metadata(only=[table_name]) + Base = automap_base(metadata=self._metadata) + Base.prepare() + return getattr(Base.classes, table_name, None) + + def execute_commit(self, statement: str): + with self.get_session() as session: + session.execute(sqlalchemy.text(statement)) + + def execute_to_json(self, statement: str) -> str: + statement = re.sub(r"/\*.*?\*/", "", statement, flags=re.DOTALL).strip() + if not statement.upper().startswith("SELECT"): + return "Only select statement supported" + try: + result = [] + with self.get_session() as session: + cursor_result = session.execute(sqlalchemy.text(statement)) + columns = list(cursor_result.keys()) + result = [dict(zip(columns, row)) for row in cursor_result] + str_result = json.dumps(result, ensure_ascii=False, default=self._serialize_uncommon_type) + except Exception as e: + str_result = f"Execute SQL ERROR: {str(e)}" return str_result def _create_by_script(self, table: str) -> DBResult: status = DBStatus.SUCCESS detail = "Success" try: - with self._engine.connect() as conn: + with self.engine.connect() as conn: conn.execute(sqlalchemy.text(table)) conn.commit() except OperationalError as e: status = DBStatus.FAIL detail = f"ERROR: {str(e)}" - finally: - if "conn" in locals(): - conn.close() return DBResult(status=status, detail=detail) - def _match_exist_table(self, table: DeclarativeMeta) -> DBResult: - status = DBStatus.SUCCESS - detail = f"Table {table.__tablename__} already exists." - metadata = sqlalchemy.MetaData() - exist_table = sqlalchemy.Table(table.__tablename__, metadata, autoload_with=self._engine) - if len(table.__table__.columns) != len(exist_table.columns): - status = DBStatus.FAIL - detail += ( - f"\n Column number mismatch: {len(table.__table__.columns)} VS " f"{len(exist_table.columns)}(exists)" - ) - return DBResult(status=status, detail=detail) - for exist_column in exist_table.columns: - target_column = getattr(table, exist_column.name) - exist_type = type(exist_column.type) - target_type = type(target_column.type) - type_is_subclass = issubclass(exist_type, target_type) or issubclass(target_type, exist_type) - if target_type is not sqlalchemy.types.TypeEngine and not type_is_subclass: - detail += f"type mismatch {exist_type} vs {target_type}" - return DBResult(status=DBStatus.FAIL, detail=detail) - for attr in ["primary_key", "nullable"]: - if getattr(exist_column, attr) != getattr(target_column, attr): - detail += f"{attr} mismatch {getattr(exist_column, attr)} vs {getattr(target_column, attr)}" - return DBResult(status=DBStatus.FAIL, detail=detail) + def _create_by_api(self, table: Union[DeclarativeBase, DeclarativeMeta]) -> DBResult: + table.metadata.create_all(bind=self.engine, checkfirst=True) return DBResult() - def _create_by_api(self, table: DeclarativeMeta) -> DBResult: - try: - table.metadata.create_all(bind=self._engine) - return DBResult() - except Exception as e: - if "already exists" in str(e): - return self._match_exist_table(table) - - def create(self, table: Union[str, DeclarativeMeta]) -> DBResult: + def create_table(self, table: Union[str, Type[DeclarativeBase], Type[DeclarativeMeta]]) -> DBResult: status = DBStatus.SUCCESS detail = "Success" if isinstance(table, str): return self._create_by_script(table) - elif isinstance(table, DeclarativeMeta): + # Support DeclarativeMeta created by declarative_base() which is deprecated since: 2.0 + elif issubclass(table, DeclarativeBase) or isinstance(table, DeclarativeMeta): return self._create_by_api(table) else: status = DBStatus.FAIL - detail += "\n Unsupported Type: {table}" + detail += f"Failed: Unsupported Type: {table}" return DBResult(status=status, detail=detail) - def drop(self, table) -> DBResult: - metadata = sqlalchemy.MetaData() + def drop_table(self, table: Union[str, Type[DeclarativeBase], Type[DeclarativeMeta]]) -> DBResult: + metadata = self._metadata if isinstance(table, str): tablename = table - elif isinstance(table, DeclarativeMeta): + elif issubclass(table, DeclarativeBase) or isinstance(table, DeclarativeMeta): tablename = table.__tablename__ else: return DBResult(status=DBStatus.FAIL, detail=f"{table} type unsupported") - Table = sqlalchemy.Table(tablename, metadata, autoload_with=self._engine) - Table.drop(self._engine, checkfirst=True) + Table = sqlalchemy.Table(tablename, metadata, autoload_with=self.engine) + Table.drop(self.engine, checkfirst=True) return DBResult() - def insert(self, statement) -> DBResult: - if isinstance(statement, (str, sqlalchemy.Insert)): - return self.execute(statement) - elif isinstance(statement, dict): - table_name = statement.get("table_name", None) - table_data = statement.get("table_data", []) - returning = statement.get("returning", []) - if not table_name: - return DBResult(status=DBStatus.FAIL, detail="No table_name found") - if not table_data: - return DBResult(status=DBStatus.FAIL, detail="No table_data found") - metadata = sqlalchemy.MetaData() - table = sqlalchemy.Table(table_name, metadata, autoload_with=self._engine) - if not returning: - statement = sqlalchemy.insert(table).values(table_data) - else: - return_columns = [sqlalchemy.column(ele) for ele in returning] - statement = (sqlalchemy.insert(table).values(table_data)).returning(*return_columns) - return self.execute(statement) - else: - return DBResult(status=DBStatus.FAIL, detail="statement type not supported") - - def update(self, statement) -> DBResult: - if isinstance(statement, (str, sqlalchemy.Update)): - return self.execute(statement) - else: - return DBResult(status=DBStatus.FAIL, detail="statement type not supported") - - def delete(self, statement) -> DBResult: - if isinstance(statement, (str, sqlalchemy.Delete)): - if isinstance(statement, str): - tmp = statement.rstrip() - if len(tmp.split()) == 1: - statement = f"DELETE FROM {tmp}" - return self.execute(statement) - else: - return DBResult(status=DBStatus.FAIL, detail="statement type not supported") - - def select(self, statement) -> DBResult: - if isinstance(statement, (str, sqlalchemy.Select)): - return self.execute(statement) - else: - return DBResult(status=DBStatus.FAIL, detail="statement type not supported") + def insert_values(self, table_name: str, vals: List[dict]): + # Refresh metadata in case of tables created by other api + TableCls = self.get_table_orm_class(table_name) + with self.get_session() as session: + session.bulk_insert_mappings(TableCls, vals) class ColumnInfo(pydantic.BaseModel): @@ -309,7 +209,7 @@ class TableInfo(pydantic.BaseModel): class TablesInfo(pydantic.BaseModel): tables: list[TableInfo] -class SqlManager(SqlManagerBase): +class SqlManager(SqlBase): PYTYPE_TO_SQL_MAP = { "integer": sqlalchemy.Integer, "string": sqlalchemy.Text, @@ -324,7 +224,6 @@ class SqlManager(SqlManagerBase): "list": sqlalchemy.ARRAY, "dict": sqlalchemy.JSON, "uuid": sqlalchemy.Uuid, - "any": sqlalchemy.types.TypeEngine, } def __init__( @@ -338,73 +237,41 @@ def __init__( tables_info_dict: dict, options_str: str = "", ) -> None: - self._tables_info_dict = tables_info_dict - super().__init__(db_type, user, password, host, port, db_name, options_str, set_default_des=False) - - def reset_engine(self, db_type, user, password, host, port, db_name, options_str): - super().reset_engine(db_type, user, password, host, port, db_name, options_str) - db_result = self.reset_table_info_dict(self._tables_info_dict) - self.status = db_result.status - self.detail = db_result.detail - if self.status != DBStatus.SUCCESS: - raise ValueError(self.detail) - return db_result - - def reset_table_info_dict(self, tables_info_dict: dict) -> DBResult: - self.status = DBStatus.SUCCESS - self.detail = "Success" - self._tables_info_dict = tables_info_dict + super().__init__(db_type, user, password, host, port, db_name, options_str) try: - tables_info = TablesInfo.model_validate(self._tables_info_dict) + self._tables_info = TablesInfo.model_validate(tables_info_dict) + self._visible_tables = [table_info.name for table_info in self._tables_info.tables] + # create table if not exist + self.create_tables_by_info(self._tables_info) + self.set_desc(self._tables_info) except pydantic.ValidationError as e: - self.status, self.detail = DBStatus.FAIL, str(e) - return DBResult(status=DBStatus.FAIL, detail=str(e)) - # Create or Check tables - created_tables = [] + raise ValueError(f"Validate tables_info_dict failed: {str(e)}") + + def create_tables_by_info(self, tables_info: TablesInfo): for table_info in tables_info.tables: TableClass = self._create_table_cls(table_info) - db_result = self.create(TableClass) - if db_result.status != DBStatus.SUCCESS: - # drop partial created table - for created_table in created_tables: - self.drop(created_table) - return db_result - created_tables.append(TableClass) - - db_result = self.set_visible_tables([ele.__tablename__ for ele in created_tables]) - if db_result.status != DBStatus.SUCCESS: - return db_result - return self.set_desc() - - def _create_table_cls(self, table_info: TableInfo) -> DeclarativeMeta: - attrs = {"__tablename__": table_info.name} - for column_info in table_info.columns: - column_type = column_info.data_type.lower() - is_nullable = column_info.nullable - column_name = column_info.name - is_primary = column_info.is_primary_key - real_type = self.PYTYPE_TO_SQL_MAP[column_type] - attrs[column_name] = sqlalchemy.Column(real_type, nullable=is_nullable, primary_key=is_primary) - TableClass = type(table_info.name.capitalize(), (TableBase,), attrs) - return TableClass + self.create_table(TableClass) - def set_desc(self) -> DBResult: - self._desc = "" - try: - tables_info = TablesInfo.model_validate(self._tables_info_dict) - except pydantic.ValidationError as e: - self.status, self.detail = DBStatus.FAIL, str(e) - return DBResult(status=DBStatus.FAIL, detail=str(e)) + @property + def desc(self) -> str: + return self._desc + + @property + def tables_info(self) -> TablesInfo: + return self._tables_info + + def set_desc(self, tables_info: TablesInfo): self._desc = "The tables description is as follows\n```\n" for table_info in tables_info.tables: self._desc += f'Table "{table_info.name}"' if table_info.comment: self._desc += f' comment "{table_info.comment}"' self._desc += "\n(\n" - real_columns = self._get_table_columns(table_info.name) + TableCls = self.get_table_orm_class(table_info.name) + real_columns = TableCls.__table__.columns column_type_dict = {} for real_column in real_columns: - column_type_dict[real_column["name"]] = real_column["type"] + column_type_dict[real_column.name] = real_column.type for i, column_info in enumerate(table_info.columns): self._desc += f"{column_info.name} {column_type_dict[column_info.name]}" if column_info.comment: @@ -414,22 +281,26 @@ def set_desc(self) -> DBResult: self._desc += "\n" self._desc += ");\n" self._desc += "```\n" - return DBResult() + @property + def llm_visible_tables(self): + return self._llm_visible_tables -class SQLiteManger(SqlManager): - - def __init__(self, db_path: str, tables_info_dict: dict = {}): - result = self.reset_engine(db_path, tables_info_dict) - self.status, self.detail = result.status, result.detail - if self.status != DBStatus.SUCCESS: - raise ValueError(self.detail) + @llm_visible_tables.setter + def llm_visible_tables(self, visible_tables: list): + raise AttributeError("Cannot set attribute 'llm_visible_tables' in SqlManager") - def reset_engine(self, db_path: str, tables_info_dict: dict): - self._db_type = "sqlite" - self.status = DBStatus.SUCCESS - self.detail = "" - self._conn_url = f"sqlite:///{db_path}" - self._extra_fields = {} - self._engine = sqlalchemy.create_engine(self._conn_url) - return self.reset_table_info_dict(tables_info_dict) + def _create_table_cls(self, table_info: TableInfo) -> Type[DeclarativeBase]: + attrs = {"__tablename__": table_info.name, "__table_args__": {"extend_existing": True}, + "metadata": self._metadata} + for column_info in table_info.columns: + column_type = column_info.data_type.lower() + is_nullable = column_info.nullable + column_name = column_info.name + is_primary = column_info.is_primary_key + # Use text for unsupported column type + real_type = self.PYTYPE_TO_SQL_MAP.get(column_type, sqlalchemy.Text) + attrs[column_name] = sqlalchemy.Column(real_type, nullable=is_nullable, primary_key=is_primary) + # When create dynamic class with same name, old version will be replaced + TableClass = type(table_info.name.capitalize(), (TableBase,), attrs) + return TableClass diff --git a/lazyllm/tools/sql/sql_tool.py b/lazyllm/tools/sql/sql_tool.py deleted file mode 100644 index 90f5c420..00000000 --- a/lazyllm/tools/sql/sql_tool.py +++ /dev/null @@ -1,390 +0,0 @@ -from lazyllm.module import ModuleBase -import lazyllm -from lazyllm.components import ChatPrompter -from lazyllm.tools.utils import chat_history_to_str -from lazyllm import pipeline, globals, bind, _0, switch -import json -from typing import List, Any, Dict, Union -import datetime -import re -import sqlalchemy -from sqlalchemy.exc import SQLAlchemyError, OperationalError -from sqlalchemy.orm import declarative_base -import pydantic -from urllib.parse import quote_plus - - -class ColumnInfo(pydantic.BaseModel): - name: str - data_type: str - comment: str = "" - # At least one column should be True - is_primary_key: bool = False - nullable: bool = True - - -class TableInfo(pydantic.BaseModel): - name: str - comment: str = "" - columns: list[ColumnInfo] - - -class TablesInfo(pydantic.BaseModel): - tables: list[TableInfo] - - -class SqlManager(ModuleBase): - DB_TYPE_SUPPORTED = set(["PostgreSQL", "MySQL", "MSSQL", "SQLite"]) - SUPPORTED_DATA_TYPES = { - "integer": sqlalchemy.Integer, - "string": sqlalchemy.String, - "text": sqlalchemy.Text, - "boolean": sqlalchemy.Boolean, - "float": sqlalchemy.Float, - } - - def __init__( - self, - db_type: str, - user: str, - password: str, - host: str, - port: int, - db_name: str, - tables_info_dict: dict, - options_str: str = "", - ) -> None: - super().__init__() - if db_type.lower() != "sqlite": - password = quote_plus(password) - conn_url = f"{db_type.lower()}://{user}:{password}@{host}:{port}/{db_name}" - self.reset_db(db_type, conn_url, tables_info_dict, options_str) - - def forward(self, sql_script: str) -> str: - return self.get_query_result_in_json(sql_script) - - def reset_tables(self, tables_info_dict: dict) -> tuple[bool, str]: - existing_tables = set(self.get_all_tables()) - try: - tables_info = TablesInfo.model_validate(tables_info_dict) - except pydantic.ValidationError as e: - lazyllm.LOG.warning(str(e)) - return False, str(e) - for table_info in tables_info.tables: - if table_info.name not in existing_tables: - # create table - cur_rt, cur_err_msg = self._create_table(table_info.model_dump()) - else: - # check table - cur_rt, cur_err_msg = self._check_columns_match(table_info.model_dump()) - if not cur_rt: - lazyllm.LOG.warning(f"cur_err_msg: {cur_err_msg}") - return cur_rt, cur_err_msg - rt, err_msg = self._set_tables_desc_prompt(tables_info_dict) - if not rt: - lazyllm.LOG.warning(err_msg) - return True, "Success" - - def reset_db(self, db_type: str, conn_url: str, tables_info_dict: dict, options_str=""): - assert db_type in self.DB_TYPE_SUPPORTED - extra_fields = {} - if options_str: - extra_fields = { - key: value for key_value in options_str.split("&") for key, value in (key_value.split("="),) - } - self.db_type = db_type - self.conn_url = conn_url - self.extra_fields = extra_fields - self.engine = sqlalchemy.create_engine(conn_url) - self.tables_prompt = "" - rt, err_msg = self.reset_tables(tables_info_dict) - if not rt: - self.err_msg = err_msg - self.err_code = 1001 - else: - self.err_code = 0 - - def get_tables_desc(self): - return self.tables_prompt - - def check_connection(self) -> tuple[bool, str]: - try: - with self.engine.connect() as _: - return True, "Success" - except SQLAlchemyError as e: - return False, str(e) - - def get_query_result_in_json(self, sql_script) -> str: - str_result = "" - try: - with self.engine.connect() as conn: - result = conn.execute(sqlalchemy.text(sql_script)) - columns = list(result.keys()) - result_dict = [dict(zip(columns, row)) for row in result] - str_result = json.dumps(result_dict, ensure_ascii=False) - except OperationalError as e: - str_result = f"ERROR: {str(e)}" - finally: - if "conn" in locals(): - conn.close() - return str_result - - def get_all_tables(self) -> list: - inspector = sqlalchemy.inspect(self.engine) - table_names = inspector.get_table_names(schema=self.extra_fields.get("schema", None)) - return table_names - - def execute_sql_update(self, sql_script): - rt, err_msg = True, "Success" - try: - with self.engine.connect() as conn: - conn.execute(sqlalchemy.text(sql_script)) - conn.commit() - except OperationalError as e: - lazyllm.LOG.warning(f"sql error: {str(e)}") - rt, err_msg = False, str(e) - finally: - if "conn" in locals(): - conn.close() - return rt, err_msg - - def _get_table_columns(self, table_name: str): - inspector = sqlalchemy.inspect(self.engine) - columns = inspector.get_columns(table_name, schema=self.extra_fields.get("schema", None)) - return columns - - def _create_table(self, table_info_dict: dict) -> tuple[bool, str]: - rt, err_msg = True, "Success" - try: - table_info = TableInfo.model_validate(table_info_dict) - except pydantic.ValidationError as e: - return False, str(e) - try: - with self.engine.connect() as conn: - Base = declarative_base() - # build table dynamically - attrs = {"__tablename__": table_info.name} - for column_info in table_info.columns: - column_type = column_info.data_type.lower() - is_nullable = column_info.nullable - column_name = column_info.name - is_primary = column_info.is_primary_key - if column_type not in self.SUPPORTED_DATA_TYPES: - return False, f"Unsupported column type: {column_type}" - real_type = self.SUPPORTED_DATA_TYPES[column_type] - attrs[column_name] = sqlalchemy.Column(real_type, nullable=is_nullable, primary_key=is_primary) - TableClass = type(table_info.name.capitalize(), (Base,), attrs) - Base.metadata.create_all(self.engine) - except OperationalError as e: - rt, err_msg = False, f"ERROR: {str(e)}" - finally: - if "conn" in locals(): - conn.close() - return rt, err_msg - - def _delete_rows_by_name(self, table_name): - metadata = sqlalchemy.MetaData() - metadata.reflect(bind=self.engine) - rt, err_msg = True, "Success" - try: - with self.engine.connect() as conn: - table = sqlalchemy.Table(table_name, metadata, autoload_with=self.engine) - delete = table.delete() - conn.execute(delete) - conn.commit() - except SQLAlchemyError as e: - rt, err_msg = False, str(e) - return rt, err_msg - - def _drop_table_by_name(self, table_name): - metadata = sqlalchemy.MetaData() - metadata.reflect(bind=self.engine) - rt, err_msg = True, "Success" - try: - table = sqlalchemy.Table(table_name, metadata, autoload_with=self.engine) - table.drop(bind=self.engine, checkfirst=True) - except SQLAlchemyError as e: - lazyllm.LOG.warning("GET SQLAlchemyError") - rt, err_msg = False, str(e) - return rt, err_msg - - def _check_columns_match(self, table_info_dict: dict): - try: - table_info = TableInfo.model_validate(table_info_dict) - except pydantic.ValidationError as e: - return False, str(e) - real_columns = self._get_table_columns(table_info.name) - tmp_dict = {} - for real_column in real_columns: - tmp_dict[real_column["name"]] = (real_column["type"], real_column["nullable"]) - for column_info in table_info.columns: - if column_info.name not in tmp_dict: - return False, f"Table {table_info.name} exists but column {column_info.name} does not." - real_column = tmp_dict[column_info.name] - column_type = column_info.data_type.lower() - if column_type not in self.SUPPORTED_DATA_TYPES: - return False, f"Unsupported column type: {column_type}" - # 1. check data type - # string type sometimes changes to other type (such as varchar) - real_type_cls = real_column[0].__class__ - if column_type != real_type_cls.__name__.lower() and not issubclass( - real_type_cls, self.SUPPORTED_DATA_TYPES[column_type] - ): - return ( - False, - f"Table {table_info.name} exists but column {column_info.name} data_type mismatch" - f": {column_info.data_type} vs {real_column[0].__class__.__name__}", - ) - # 2. check nullable - if column_info.nullable != real_column[1]: - return False, f"Table {table_info.name} exists but column {column_info.name} nullable mismatch" - if len(tmp_dict) > len(table_info.columns): - return ( - False, - f"Table {table_info.name} exists but has more columns. {len(tmp_dict)} vs {len(table_info.columns)}", - ) - return True, "Match" - - def _set_tables_desc_prompt(self, tables_info_dict: dict) -> str: - try: - tables_info = TablesInfo.model_validate(tables_info_dict) - except pydantic.ValidationError as e: - return False, str(e) - self.tables_prompt = "The tables description is as follows\n```\n" - for table_info in tables_info.tables: - self.tables_prompt += f'Table "{table_info.name}"' - if table_info.comment: - self.tables_prompt += f' comment "{table_info.comment}"' - self.tables_prompt += "\n(\n" - for i, column_info in enumerate(table_info.columns): - self.tables_prompt += f"{column_info.name} {column_info.data_type}" - if column_info.comment: - self.tables_prompt += f' comment "{column_info.comment}"' - if i != len(table_info.columns) - 1: - self.tables_prompt += "," - self.tables_prompt += "\n" - self.tables_prompt += ");\n" - self.tables_prompt += "```\n" - return True, "Success" - - -class SQLiteManger(SqlManager): - - def __init__(self, db_file, tables_info_dict: dict): - super().__init__("SQLite", "", "", "", 0, "", {}, "") - super().reset_db("SQLite", f"sqlite:///{db_file}", tables_info_dict) - - -sql_query_instruct_template = """ -Given the following SQL tables and current date {current_date}, your job is to write sql queries in {db_type} given a user’s request. -Alert: Just replay the sql query in a code block. - -{sql_tables} -""" # noqa E501 - -sql_explain_instruct_template = """ -According to chat history -``` -{history_info} -``` - -bellowing sql query is executed - -``` -{sql_query} -``` -the sql result is -``` -{sql_result} -``` -""" - - -class SqlCall(ModuleBase): - def __init__( - self, - llm, - sql_manager: SqlManager, - sql_examples: str = "", - use_llm_for_sql_result=True, - return_trace: bool = False, - ) -> None: - super().__init__(return_trace=return_trace) - self._sql_tool = sql_manager - self._query_prompter = ChatPrompter(instruction=sql_query_instruct_template).pre_hook(self.sql_query_promt_hook) - self._llm_query = llm.share(prompt=self._query_prompter).used_by(self._module_id) - self._answer_prompter = ChatPrompter(instruction=sql_explain_instruct_template).pre_hook( - self.sql_explain_prompt_hook - ) - self._llm_answer = llm.share(prompt=self._answer_prompter).used_by(self._module_id) - self._pattern = re.compile(r"```sql(.+?)```", re.DOTALL) - with pipeline() as sql_execute_ppl: - sql_execute_ppl.exec = self._sql_tool - if use_llm_for_sql_result: - sql_execute_ppl.concate = (lambda q, r: [q, r]) | bind(sql_execute_ppl.input, _0) - sql_execute_ppl.llm_answer = self._llm_answer - with pipeline() as ppl: - ppl.llm_query = self._llm_query - ppl.sql_extractor = self.extract_sql_from_response - with switch(judge_on_full_input=False) as ppl.sw: - ppl.sw.case[False, lambda x: x] - ppl.sw.case[True, sql_execute_ppl] - self._impl = ppl - - def sql_query_promt_hook( - self, - input: Union[str, List, Dict[str, str], None] = None, - history: List[Union[List[str], Dict[str, Any]]] = [], - tools: Union[List[Dict[str, Any]], None] = None, - label: Union[str, None] = None, - ): - current_date = datetime.datetime.now().strftime("%Y-%m-%d") - sql_tables_info = self._sql_tool.get_tables_desc() - if not isinstance(input, str): - raise ValueError(f"Unexpected type for input: {type(input)}") - return ( - dict( - current_date=current_date, db_type=self._sql_tool.db_type, sql_tables=sql_tables_info, user_query=input - ), - history, - tools, - label, - ) - - def sql_explain_prompt_hook( - self, - input: Union[str, List, Dict[str, str], None] = None, - history: List[Union[List[str], Dict[str, Any]]] = [], - tools: Union[List[Dict[str, Any]], None] = None, - label: Union[str, None] = None, - ): - explain_query = "Tell the user based on the sql execution results, making sure to keep the language consistent \ - with the user's input and don't translate original result." - if not isinstance(input, list) and len(input) != 2: - raise ValueError(f"Unexpected type for input: {type(input)}") - assert "root_input" in globals and self._llm_answer._module_id in globals["root_input"] - user_query = globals["root_input"][self._llm_answer._module_id] - globals.pop("root_input") - history_info = chat_history_to_str(history, user_query) - return ( - dict(history_info=history_info, sql_query=input[0], sql_result=input[1], explain_query=explain_query), - history, - tools, - label, - ) - - def extract_sql_from_response(self, str_response: str) -> tuple[bool, str]: - # Remove the triple backticks if present - matches = self._pattern.findall(str_response) - if matches: - # Return the first match - extracted_content = matches[0].strip() - return True, extracted_content - else: - return False, str_response - - def forward(self, input: str, llm_chat_history: List[Dict[str, Any]] = None): - globals["root_input"] = {self._llm_answer._module_id: input} - if self._module_id in globals["chat_history"]: - globals["chat_history"][self._llm_query._module_id] = globals["chat_history"][self._module_id] - return self._impl(input) diff --git a/tests/advanced_tests/standard_test/test_mongodb_manager.py b/tests/advanced_tests/standard_test/test_mongodb_manager.py index ca34ba3b..096d4444 100644 --- a/tests/advanced_tests/standard_test/test_mongodb_manager.py +++ b/tests/advanced_tests/standard_test/test_mongodb_manager.py @@ -85,23 +85,23 @@ class MongoDBEgsData: ] -class TestSqlManager(unittest.TestCase): +class TestMongoDBManager(unittest.TestCase): @classmethod def clean_obsolete_tables(cls, mongodb_manager: MongoDBManager): today = datetime.datetime.now() pattern = r"^(?:america)_(\d{8})_(\w+)" OBSOLETE_DAYS = 2 - db_result = mongodb_manager.get_all_collections() - assert db_result.status == DBStatus.SUCCESS, db_result.detail - existing_collections = db_result.result - for collection_name in existing_collections: - match = re.match(pattern, collection_name) - if not match: - continue - table_create_date = datetime.datetime.strptime(match.group(1), "%Y%m%d") - delta = (today - table_create_date).days - if delta >= OBSOLETE_DAYS: - mongodb_manager.drop_collection(collection_name) + with mongodb_manager.get_client() as client: + db = client[mongodb_manager.db_name] + existing_collections = db.list_collection_names() + for collection_name in existing_collections: + match = re.match(pattern, collection_name) + if not match: + continue + table_create_date = datetime.datetime.strptime(match.group(1), "%Y%m%d") + delta = (today - table_create_date).days + if delta >= OBSOLETE_DAYS: + db.drop_collection(collection_name) @classmethod def setUpClass(cls): @@ -118,8 +118,10 @@ def setUpClass(cls): cls.mongodb_manager = MongoDBManager(username, password, host, port, database, MongoDBEgsData.COLLECTION_NAME) cls.clean_obsolete_tables(cls.mongodb_manager) - cls.mongodb_manager.delete({}) - cls.mongodb_manager.insert(MongoDBEgsData.COLLECTION_DATA) + with cls.mongodb_manager.get_client() as client: + collection = client[cls.mongodb_manager.db_name][cls.mongodb_manager.collection_name] + collection.delete_many({}) + collection.insert_many(MongoDBEgsData.COLLECTION_DATA) cls.mongodb_manager.set_desc( { "summary": "美国各个城市的人口情况", @@ -135,7 +137,9 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): # restore to clean database - cls.mongodb_manager.drop_collection(MongoDBEgsData.COLLECTION_NAME) + with cls.mongodb_manager.get_client() as client: + collection = client[cls.mongodb_manager.db_name][cls.mongodb_manager.collection_name] + collection.drop() def test_manager_status(self): db_result = self.mongodb_manager.check_connection() @@ -143,26 +147,32 @@ def test_manager_status(self): def test_manager_table_delete_insert_query(self): # delete all documents - self.mongodb_manager.delete({}) - db_result = self.mongodb_manager.select({}) - assert db_result.status == DBStatus.SUCCESS, db_result.detail - assert len(db_result.result) == 0 + with self.mongodb_manager.get_client() as client: + collection = client[self.mongodb_manager.db_name][self.mongodb_manager.collection_name] + collection.delete_many({}) + results = list(collection.find({})) + assert len(results) == 0 - # insert one document - self.mongodb_manager.insert(MongoDBEgsData.COLLECTION_DATA[0]) - # insert many documents - self.mongodb_manager.insert(MongoDBEgsData.COLLECTION_DATA[1:]) + # insert one document + collection.insert_one(MongoDBEgsData.COLLECTION_DATA[0]) + # insert many documents + collection.insert_many(MongoDBEgsData.COLLECTION_DATA[1:]) - db_result = self.mongodb_manager.select({}) - assert db_result.status == DBStatus.SUCCESS, db_result.detail - assert len(db_result.result) == len(MongoDBEgsData.COLLECTION_DATA) + results = list(collection.find({})) + assert len(results) == len(MongoDBEgsData.COLLECTION_DATA) def test_select(self): - db_result = self.mongodb_manager.select({"state": "TX"}) - assert db_result.status == DBStatus.SUCCESS, db_result.detail - db_result = self.mongodb_manager.select({"state": "TX"}, projection={"city": True}) - assert db_result.status == DBStatus.SUCCESS, db_result.detail - assert len(db_result.result) == sum([ele["state"] == "TX" for ele in MongoDBEgsData.COLLECTION_DATA]) + with self.mongodb_manager.get_client() as client: + collection = client[self.mongodb_manager.db_name][self.mongodb_manager.collection_name] + results = list(collection.find({"state": "TX"}, projection={"city": True})) + match_count = sum([ele["state"] == "TX" for ele in MongoDBEgsData.COLLECTION_DATA]) + assert len(results) == match_count + + def test_aggregate(self): + with self.mongodb_manager.get_client() as client: + collection = client[self.mongodb_manager.db_name][self.mongodb_manager.collection_name] + results = list(collection.aggregate([{'$group': {'_id': '$state', 'totalPop': {'$sum': '$pop'}}}, {'$match': {'totalPop': {'$gt': 3000000}}}])) + print(f"results: {results}") @unittest.skip("Just run local model in non-charge test") def test_llm_query_online(self): @@ -172,6 +182,7 @@ def test_llm_query_online(self): self.assertIn("NY", str_results) print(f"str_results:\n{str_results}") + # @unittest.skip("temporary skip test") def test_llm_query_local(self): local_llm = lazyllm.TrainableModule("qwen2-72b-instruct-awq").deploy_method(lazyllm.deploy.vllm).start() sql_call = SqlCall(local_llm, self.mongodb_manager, use_llm_for_sql_result=True, return_trace=True) diff --git a/tests/charge_tests/test_sql_tool.py b/tests/charge_tests/test_sql_tool.py index ebfac9b2..ab73986f 100644 --- a/tests/charge_tests/test_sql_tool.py +++ b/tests/charge_tests/test_sql_tool.py @@ -1,5 +1,5 @@ import unittest -from lazyllm.tools import SQLiteManger, SqlCall, SqlManager, DBStatus +from lazyllm.tools import SqlCall, SqlManager, DBStatus import lazyllm from .utils import SqlEgsData, get_db_init_keywords import datetime @@ -12,9 +12,7 @@ def clean_obsolete_tables(cls, sql_manager: SqlManager): today = datetime.datetime.now() pattern = r"^(?:employee|sales)_(\d{8})_(\w+)" OBSOLETE_DAYS = 2 - db_result = sql_manager.get_all_tables() - assert db_result.status == DBStatus.SUCCESS, db_result.detail - existing_tables = db_result.result + existing_tables = sql_manager.get_all_tables() for table_name in existing_tables: match = re.match(pattern, table_name) if not match: @@ -22,11 +20,13 @@ def clean_obsolete_tables(cls, sql_manager: SqlManager): table_create_date = datetime.datetime.strptime(match.group(1), "%Y%m%d") delta = (today - table_create_date).days if delta >= OBSOLETE_DAYS: - sql_manager.drop(table_name) + sql_manager.drop_table(table_name) @classmethod def setUpClass(cls): - cls.sql_managers: list[SqlManager] = [SQLiteManger(":memory:", SqlEgsData.TEST_TABLES_INFO)] + cls.sql_managers: list[SqlManager] = [SqlManager("SQLite", None, None, None, None, db_name=":memory:", + tables_info_dict=SqlEgsData.TEST_TABLES_INFO)] + # MySQL has been tested with online database. for db_type in ["PostgreSQL"]: username, password, host, port, database = get_db_init_keywords(db_type) cls.sql_managers.append( @@ -34,12 +34,11 @@ def setUpClass(cls): ) for sql_manager in cls.sql_managers: cls.clean_obsolete_tables(sql_manager) + for table_name in SqlEgsData.TEST_TABLES: - db_result = sql_manager.delete(table_name) - assert db_result.status == DBStatus.SUCCESS, db_result.detail + sql_manager.execute_commit(f"DELETE FROM {table_name}") for insert_script in SqlEgsData.TEST_INSERT_SCRIPTS: - db_result = sql_manager.execute(insert_script) - assert db_result.status == DBStatus.SUCCESS, db_result.detail + sql_manager.execute_commit(insert_script) # Recommend to use sensenova, gpt-4o, qwen online model sql_llm = lazyllm.OnlineChatModule(source="qwen") @@ -52,58 +51,55 @@ def tearDownClass(cls): # restore to clean database for sql_manager in cls.sql_managers: for table_name in SqlEgsData.TEST_TABLES: - db_result = sql_manager.drop(table_name) + db_result = sql_manager.drop_table(table_name) assert db_result.status == DBStatus.SUCCESS, db_result.detail def test_manager_status(self): for sql_manager in self.sql_managers: db_result = sql_manager.check_connection() assert db_result.status == DBStatus.SUCCESS, db_result.detail + + def test_manager_orm_operation(self): + for sql_manager in self.sql_managers: + table_name = SqlEgsData.TEST_TABLES[0] + TableCls = sql_manager.get_table_orm_class(table_name) + sql_manager.insert_values(table_name, SqlEgsData.TEST_EMPLOYEE_INSERT_VALS) - def test_manager_table_create_drop(self): + with sql_manager.get_session() as session: + item = session.query(TableCls).filter(TableCls.employee_id == 1111).first() + assert item.name == "四一" + + + def test_manager_create_tables(self): for sql_manager in self.sql_managers: # 1. drop tables for table_name in SqlEgsData.TEST_TABLES: - db_result = sql_manager.drop(table_name) + db_result = sql_manager.drop_table(table_name) assert db_result.status == DBStatus.SUCCESS, db_result.detail - db_result = sql_manager.get_all_tables() - assert db_result.status == DBStatus.SUCCESS, db_result.detail - existing_tables = set(db_result.result) + existing_tables = set(sql_manager.get_all_tables()) for table_name in SqlEgsData.TEST_TABLES: assert table_name not in existing_tables # 2. create table - db_result = sql_manager.reset_table_info_dict(SqlEgsData.TEST_TABLES_INFO) - assert db_result.status == DBStatus.SUCCESS, db_result.detail - + sql_manager.create_tables_by_info(sql_manager.tables_info) # 3. restore rows for insert_script in SqlEgsData.TEST_INSERT_SCRIPTS: - db_result = sql_manager.execute(insert_script) - assert db_result.status == DBStatus.SUCCESS, db_result.detail + sql_manager.execute_commit(insert_script) def test_manager_table_delete_insert_query(self): # 1. Delete, as rows already exists during setUp for sql_manager in self.sql_managers: for table_name in SqlEgsData.TEST_TABLES: - db_result = sql_manager.delete(table_name) - assert db_result.status == DBStatus.SUCCESS, db_result.detail + sql_manager.execute_commit(f"DELETE FROM {table_name}") str_results = sql_manager.execute_to_json(SqlEgsData.TEST_QUERY_SCRIPTS) self.assertNotIn("销售一部", str_results) # 2. Insert, restore rows for sql_manager in self.sql_managers: for insert_script in SqlEgsData.TEST_INSERT_SCRIPTS: - db_result = sql_manager.execute(insert_script) - assert db_result.status == DBStatus.SUCCESS, db_result.detail + sql_manager.execute_commit(insert_script) str_results = sql_manager.execute_to_json(SqlEgsData.TEST_QUERY_SCRIPTS) self.assertIn("销售一部", f"Query: {SqlEgsData.TEST_QUERY_SCRIPTS}; result: {str_results}") - def test_get_tables(self): - for sql_manager in self.sql_managers: - db_result = sql_manager.get_all_tables() - assert db_result.status == DBStatus.SUCCESS, db_result.detail - for table_name in SqlEgsData.TEST_TABLES: - self.assertIn(table_name, db_result.result) - def test_llm_query_online(self): for sql_call in self.sql_calls: str_results = sql_call("去年一整年销售额最多的员工是谁,销售额是多少?") diff --git a/tests/charge_tests/utils.py b/tests/charge_tests/utils.py index a5b50c27..d71346e9 100644 --- a/tests/charge_tests/utils.py +++ b/tests/charge_tests/utils.py @@ -54,6 +54,10 @@ class SqlEgsData: f"INSERT INTO {TEST_TABLES[1]} VALUES (2, 4989.23, 5103.22, 4897.98, 5322.05);", f"INSERT INTO {TEST_TABLES[1]} VALUES (11, 5989.23, 6103.22, 2897.98, 3322.05);", ] + TEST_EMPLOYEE_INSERT_VALS = [ + {"employee_id": 1111, "name": "四一", "department": "IT"}, + {"employee_id": 11111, "name": "五一", "department": "IT"} + ] TEST_QUERY_SCRIPTS = f"SELECT department from {TEST_TABLES[0]} WHERE employee_id=1;" class MongoDBEgsData: