From 9a0a28869ea863854c9093f1f0c087bde5c505d2 Mon Sep 17 00:00:00 2001 From: Jesus Lara Date: Fri, 10 Jan 2025 02:13:41 +0100 Subject: [PATCH 1/2] new MongoDB driver using Motor --- asyncdb/drivers/aioch.py | 2 +- asyncdb/drivers/bigquery.py | 76 ++-- asyncdb/drivers/clickhouse.py | 2 +- asyncdb/drivers/duckdb.py | 2 +- asyncdb/drivers/dummy.py | 3 +- asyncdb/drivers/elastic.py | 3 +- asyncdb/drivers/influx.py | 2 +- asyncdb/drivers/mongo.py | 591 ++++++++++++++++++++++--------- asyncdb/drivers/mredis.py | 9 +- asyncdb/drivers/mysql.py | 5 +- asyncdb/drivers/mysqlclient.py | 4 +- asyncdb/drivers/odbc.py | 2 +- asyncdb/drivers/pg.py | 4 +- asyncdb/drivers/postgres.py | 2 +- asyncdb/drivers/redis.py | 5 +- asyncdb/drivers/sa.py | 2 +- asyncdb/drivers/sqlite.py | 2 +- asyncdb/interfaces/connection.py | 8 +- asyncdb/models/model.py | 9 +- asyncdb/version.py | 3 +- examples/test_mongo.py | 41 ++- setup.py | 4 +- tests/test_mongo.py | 208 ++++++++++- 23 files changed, 740 insertions(+), 249 deletions(-) diff --git a/asyncdb/drivers/aioch.py b/asyncdb/drivers/aioch.py index eea6cfc5..44636c98 100644 --- a/asyncdb/drivers/aioch.py +++ b/asyncdb/drivers/aioch.py @@ -29,7 +29,7 @@ class aioch(SQLDriver): _provider: str = "clickhouse" _syntax: str = "sql" - _dsn: str = "{database}" + _dsn_template: str = "{database}" _test_query: str = "SELECT version()" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: diff --git a/asyncdb/drivers/bigquery.py b/asyncdb/drivers/bigquery.py index aa1dbf1d..0d105a10 100644 --- a/asyncdb/drivers/bigquery.py +++ b/asyncdb/drivers/bigquery.py @@ -11,7 +11,7 @@ import pandas as pd from google.cloud import storage from google.cloud import bigquery as bq -from google.cloud.exceptions import Conflict +from google.cloud.exceptions import Conflict, NotFound from google.cloud.bigquery import LoadJobConfig, SourceFormat from google.oauth2 import service_account from .sql import SQLDriver @@ -25,6 +25,7 @@ class bigquery(SQLDriver, ModelBackend): _provider = "bigquery" _syntax = "sql" _test_query = "SELECT 1" + _dsn_template: str = "" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: self._credentials = params.get("credentials", None) @@ -157,33 +158,67 @@ async def create_table(self, dataset_id, table_id, schema): self._logger.info(f"Created table {table.project}.{table.dataset_id}.{table.table_id}") return table except Conflict: - self._logger.warning(f"Table {table.project}.{table.dataset_id}.{table.table_id} already exists") + self._logger.warning( + f"Table {table.project}.{table.dataset_id}.{table.table_id} already exists" + ) return table except Exception as e: - raise DriverError(f"BigQuery: Error creating table: {e}") + raise DriverError(f"BigQuery: Error creating table: {e}") from e async def truncate_table(self, table_id: str, dataset_id: str): """ - Truncate a BigQuery table by overwriting with an empty table. + Truncate a BigQuery table by overwriting it with an empty table. + + Parameters: + dataset_id (str): The ID of the dataset containing the table. + table_id (str): The ID of the table to truncate. + + Raises: + DriverError: If there is an issue truncating the table. """ if not self._connection: await self.connection() - # Construct a reference to the dataset - dataset_ref = bq.DatasetReference(self._connection.project, dataset_id) - table_ref = dataset_ref.table(table_id) - table = self._connection.get_table(table_ref) # API request to fetch the table schema + try: + # Reference to the dataset and table + dataset_ref = self._connection.dataset(dataset_id) + table_ref = dataset_ref.table(table_id) + + # Ensure the table exists + try: + table = self._connection.get_table(table_ref) + except NotFound: + raise DriverError( + f"BigQuery: Table `{dataset_id}.{table_id}` does not exist." + ) - # Create an empty table with the same schema - job_config = bq.QueryJobConfig(destination=table_ref) - job_config.write_disposition = bq.WriteDisposition.WRITE_TRUNCATE + # Configure the query job to overwrite the table + job_config = bq.QueryJobConfig( + destination=table_ref, + write_disposition=bq.WriteDisposition.WRITE_TRUNCATE, + allow_large_results=True + ) - try: - job = self._connection.query(f"SELECT * FROM `{table_ref}` WHERE FALSE", job_config=job_config) - job.result() # Wait for the job to finish - self._logger.info(f"Truncated table {dataset_id}.{table_id}") + # Execute a query that selects no rows, effectively truncating the table + query = f"SELECT * FROM `{self._project_id}.{dataset_id}.{table_id}` WHERE FALSE" + + self._logger.debug(f"Truncating table with query: {query}") + job = self._connection.query(query, job_config=job_config) + + # Wait for the job to complete + await asyncio.get_event_loop().run_in_executor(None, job.result) + + self._logger.info(f"Successfully truncated table `{dataset_id}.{table_id}`.") + return True + except DriverError: + raise except Exception as e: - raise DriverError(f"BigQuery: Error truncating table: {e}") + self._logger.error( + f"BigQuery: Error truncating table `{dataset_id}.{table_id}`: {e}" + ) + raise DriverError( + f"BigQuery: Error truncating table `{dataset_id}.{table_id}`: {e}" + ) from e async def query(self, sentence: str, **kwargs): if not self._connection: @@ -244,8 +279,7 @@ async def fetch_all(self, query, *args): """ Fetch all results from a BigQuery query """ - results = await self.execute(query, *args) - return results + return await self.execute(query, *args) async def fetch_one(self, query, *args): """ @@ -273,7 +307,7 @@ async def write( table = f"{self._project_id}.{dataset_id}.{table_id}" try: if isinstance(data, pd.DataFrame): - if use_pandas is True: + if use_pandas: job = await self._thread_func(self._connection.load_table_from_dataframe, data, table, **kwargs) else: object_cols = data.select_dtypes(include=["object"]).columns @@ -293,7 +327,7 @@ async def write( dataset_ref = self._connection.dataset(dataset_id) table_ref = dataset_ref.table(table_id) table = bq.Table(table_ref) - if use_streams is True: + if use_streams: errors = await self._thread_func(self._connection.insert_rows_json, table, data, **kwargs) if errors: raise RuntimeError(f"Errors occurred while inserting rows: {errors}") @@ -314,7 +348,7 @@ async def write( # return Job object return job except Exception as e: - raise DriverError(f"BigQuery: Error writing to table: {e}") + raise DriverError(f"BigQuery: Error writing to table: {e}") from e async def load_table_from_uri( self, diff --git a/asyncdb/drivers/clickhouse.py b/asyncdb/drivers/clickhouse.py index e48f688c..bfe26249 100644 --- a/asyncdb/drivers/clickhouse.py +++ b/asyncdb/drivers/clickhouse.py @@ -29,7 +29,7 @@ class clickhouse(SQLDriver): _provider: str = "clickhouse" _syntax: str = "sql" - _dsn: str = "" + _dsn_template: str = "" _test_query: str = "SELECT now(), version()" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: diff --git a/asyncdb/drivers/duckdb.py b/asyncdb/drivers/duckdb.py index 1c54c77d..455b09b6 100644 --- a/asyncdb/drivers/duckdb.py +++ b/asyncdb/drivers/duckdb.py @@ -45,7 +45,7 @@ async def fetch_all(self) -> Iterable[Sequence]: class duckdb(SQLDriver, DBCursorBackend): _provider: str = "duckdb" _syntax: str = "sql" - _dsn: str = "{database}" + _dsn_template: str = "{database}" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: SQLDriver.__init__(self, dsn, loop, params, **kwargs) diff --git a/asyncdb/drivers/dummy.py b/asyncdb/drivers/dummy.py index 4b56588c..cbdacce7 100644 --- a/asyncdb/drivers/dummy.py +++ b/asyncdb/drivers/dummy.py @@ -9,10 +9,11 @@ class dummy(BaseDriver): _provider = "dummy" _syntax = "sql" + _dsn_template: str = "test:/{host}:{port}/{db}" + def __init__(self, dsn: Union[str, None] = None, loop=None, params: dict = None, **kwargs): self._test_query = "SELECT 1" - self._dsn = "test:/{host}:{port}/{db}" if not params: params = {"host": "127.0.0.1", "port": "0", "db": 0} try: diff --git a/asyncdb/drivers/elastic.py b/asyncdb/drivers/elastic.py index 99beeaad..45e5abc9 100644 --- a/asyncdb/drivers/elastic.py +++ b/asyncdb/drivers/elastic.py @@ -33,6 +33,8 @@ def get_dsn(self) -> str: class elastic(BaseDriver): _provider = "elasticsearch" _syntax = "json" + _dsn_template: str = "{protocol}://{host}:{port}/" + def __init__(self, dsn: str = None, loop=None, params: Union[dict, ElasticConfig] = None, **kwargs): # self._dsn = "{protocol}://{user}:{password}@{host}:{port}/{database}" @@ -40,7 +42,6 @@ def __init__(self, dsn: str = None, loop=None, params: Union[dict, ElasticConfig self._database = params.database else: self._database = params.pop("db", "default") - self._dsn = "{protocol}://{host}:{port}/" super(elastic, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs) def create_dsn(self, params: Union[dict, dataclass]): diff --git a/asyncdb/drivers/influx.py b/asyncdb/drivers/influx.py index 9d68d52d..1aab04af 100644 --- a/asyncdb/drivers/influx.py +++ b/asyncdb/drivers/influx.py @@ -39,12 +39,12 @@ def retry(self, conf: tuple[str, str, str], data: str, exception: InfluxDBError) class influx(InitDriver, ConnectionDSNBackend): _provider = "influxdb" _syntax = "sql" + _dsn_template: str = "{protocol}://{host}:{port}" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: self._test_query = "SELECT 1" self._query_raw = "SELECT {fields} FROM {table} {where_cond}" self._version: str = None - self._dsn = "{protocol}://{host}:{port}" self._client = InfluxDBClientAsync self._enable_gzip = kwargs.get("enable_gzip", True) self._retries = Retry(connect=5, read=2, redirect=5) diff --git a/asyncdb/drivers/mongo.py b/asyncdb/drivers/mongo.py index 16b1ba11..1ee0a5d3 100644 --- a/asyncdb/drivers/mongo.py +++ b/asyncdb/drivers/mongo.py @@ -1,10 +1,12 @@ -from typing import Optional, Any -from collections.abc import Iterable, Sequence +from typing import Optional, Any, Union, Iterable, List +from collections.abc import Sequence import asyncio import time import motor.motor_asyncio import pymongo import pandas as pd +import pyarrow as pa +from dataclasses import is_dataclass, asdict from ..exceptions import ( ConnectionTimeout, DataError, @@ -19,30 +21,45 @@ class mongo(BaseDriver): """ - mongo Driver class for interacting with MongoDB asynchronously using motor. + MongoDB Driver class for interacting with MongoDB asynchronously using Motor. Attributes: ----------- _provider : str Name of the database provider ('mongodb'). - _syntax : None - Not applicable for MongoDB. + _syntax : str + Syntax type, set to 'mongo'. _dsn : str Data Source Name (DSN) for connecting to the database, if provided. _connection : motor.motor_asyncio.AsyncIOMotorClient Holds the active connection to the database. + _database : motor.motor_asyncio.AsyncIOMotorDatabase + Reference to the selected database. _connected : bool Indicates if the driver is currently connected to the database. + _database_name : str + Name of the default database to use. + _databases : list + List of available databases. + _timeout : int + Connection timeout in seconds. """ _provider = "mongodb" - _dsn = "'mongodb://{host}:{port}" + _dsn_template = "mongodb://{username}:{password}@{host}:{port}/{database}" _syntax = "mongo" _parameters = () _initialized_on = None _timeout: int = 5 - def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: + + def __init__( + self, + dsn: str = "", + loop: asyncio.AbstractEventLoop = None, + params: dict = None, + **kwargs + ) -> None: """ Initializes the MongoDBDriver with the given DSN, event loop, and optional parameters. @@ -56,78 +73,147 @@ def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params params : dict, optional Additional connection parameters as a dictionary. Defaults to None. kwargs : dict - Additional keyword arguments to pass to the base SQLDriver. - """ - if "username" in params: - self._dsn = "mongodb://{username}:{password}@{host}:{port}" - if "database" in params: - self._dsn = self._dsn + "/{database}" - self._database_name = params.get("database", kwargs.get("database", None)) - super().__init__(dsn, loop, params, **kwargs) + Additional keyword arguments to pass to the base Driver. + """ self._connection = None self._database = None - self._databases: list = [] + self._databases: List[str] = [] + self._database_name = params.get("database", kwargs.get("database", None)) + super(mongo, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs) + self._dsn = self._construct_dsn(params) + + def _construct_dsn(self, params) -> str: + """Construct DSN based on provided parameters.""" + if not self._params: + return "" + username = params.get("username") + password = params.get("password") + host = params.get("host", "localhost") + port = params.get("port", 27017) + database = self._database_name or "" + if username and password: + if database: + return self._dsn_template.format( + username=username, + password=password, + host=host, + port=port, + database=database, + ) + f"?authSource={database}" + else: + return f"mongodb://{username}:{password}@{host}:{port}/?authSource=admin" + else: + return f"mongodb://{host}:{port}/{database}" - async def connection(self): + async def _select_database(self) -> motor.motor_asyncio.AsyncIOMotorDatabase: + """ + Internal method to select the database. + + Returns: + -------- + motor.motor_asyncio.AsyncIOMotorDatabase + The selected database instance. + + Raises: + ------- + DriverError + If the database cannot be selected. """ - Get a connection + if self._database is None: + if self._database_name: + self._database = self._connection[self._database_name] + else: + raise DriverError( + "No database selected. Use 'use' method to select a database." + ) + return self._database + + async def connection(self) -> "mongo": + """ + Establishes a connection to the MongoDB server. + + Returns: + -------- + mongo + Returns the instance of the driver itself. + + Raises: + ------- + DriverError + If there is an issue establishing the connection. """ self._connection = None self._connected = False try: if self._dsn: - self._connection = motor.motor_asyncio.AsyncIOMotorClient(self._dsn) + self._connection = motor.motor_asyncio.AsyncIOMotorClient( + self._dsn, + serverSelectionTimeoutMS=self._timeout * 1000 + ) else: - params = {"host": self._params.get("host", "localhost"), "port": self._params.get("port", 27017)} - if "username" in self._params: + params = { + "host": self._params.get("host", "localhost"), + "port": self._params.get("port", 27017), + "serverSelectionTimeoutMS": self._timeout * 1000, + } + if "username" in self._params and "password" in self._params: params["username"] = self._params["username"] params["password"] = self._params["password"] self._connection = motor.motor_asyncio.AsyncIOMotorClient(**params) - try: - self._databases = await self._connection.list_database_names() - except Exception as err: - raise DriverError(f"Error Connecting to Mongo: {err}") from err - if len(self._databases) > 0: - self._connected = True - self._initialized_on = time.time() + # Attempt to fetch server info to verify connection + await self._connection.admin.command('ping') + self._connected = True + self._initialized_on = time.time() return self except Exception as err: self._connection = None - self._cursor = None - print(err) - raise DriverError(f"connection Error, Terminated: {err}") from err + self._database = None + raise DriverError(f"Connection Error, Terminated: {err}") from err - async def close(self): + async def close(self) -> None: """ - Closing a Connection + Closes the connection to the MongoDB server. + + Raises: + ------- + DriverError + If there is an issue closing the connection. """ try: if self._connection: - try: - self._connection.close() - except Exception as err: - self._connection = None - raise DriverError(f"Connection Error, Terminated: {err}") + self._connection.close() except Exception as err: - raise DriverError(f"Close Error: {err}") + raise DriverError(f"Close Error: {err}") from err finally: self._connection = None + self._database = None self._connected = False - async def test_connection(self): + def is_connected(self): + return self._connected + + async def test_connection(self, use_ping: bool = False) -> list: """ - Getting information about Server. + Tests the connection by retrieving server information. Returns: -------- - [result, error] : list + list A list containing the server information and any error that occurred. """ error = None result = None if self._connection: + if use_ping: + try: + result = await self._connection.admin.command("ping") + self._connected = True + return [result, error] + except Exception as err: + error = err try: result = await self._connection.server_info() + self._connected = True except Exception as err: error = err finally: @@ -136,7 +222,18 @@ async def test_connection(self): error = DriverError("Not connected to MongoDB") return [None, error] - async def use(self, database: str): + async def prepare(self, *args, **kwargs) -> None: + """ + Prepares a statement. MongoDB does not support prepared statements. + + Raises: + ------- + DriverError + Indicating that prepared statements are not supported. + """ + raise DriverError("MongoDB does not support prepared statements.") + + async def use(self, database: str) -> motor.motor_asyncio.AsyncIOMotorDatabase: """ Switches the current database to the specified one. @@ -147,15 +244,30 @@ async def use(self, database: str): Returns: -------- - None + motor.motor_asyncio.AsyncIOMotorDatabase + The selected database instance. + + Raises: + ------- + DriverError + If the connection is not established. """ - if self._connection: - self._database = self._connection[database] - return self._database - else: - raise DriverError(f"Not connected to MongoDB on DB {database}") + if not self._connection: + raise DriverError( + f"Not connected to MongoDB. Cannot switch to database '{database}'." + ) + self._database = self._connection[database] + self._database_name = database + return self._database + - async def execute(self, collection_name: str, operation: str, *args, **kwargs) -> Optional[Any]: + async def execute( + self, + collection_name: str, + operation: str, + *args, + **kwargs + ) -> Optional[Any]: """ Executes an operation (insert, update, delete) on a collection asynchronously. @@ -179,18 +291,21 @@ async def execute(self, collection_name: str, operation: str, *args, **kwargs) - """ error = None result = None - if not self._database: - raise DriverError("No database selected. Use 'use' method to select a database.") - - collection = self._database[collection_name] try: + db = await self._select_database() + collection = db[collection_name] method = getattr(collection, operation) result = await method(*args, **kwargs) except Exception as err: error = err return (result, error) - async def execute_many(self, collection_name: str, operation: str, documents: list) -> Optional[Any]: + async def execute_many( + self, + collection_name: str, + operation: str, + documents: list + ) -> Optional[Any]: """ Executes a bulk operation on a collection asynchronously. @@ -209,29 +324,18 @@ async def execute_many(self, collection_name: str, operation: str, documents: li Optional[Any] The result of the bulk operation, if any. """ - error = None - result = None - if not self._database: - raise DriverError("No database selected. Use 'use' method to select a database.") - - collection = self._database[collection_name] - try: - method = getattr(collection, operation) - result = await method(documents) - except Exception as err: - error = err - return (result, error) + return await self.execute(collection_name, operation, documents) executemany = execute_many - async def __aenter__(self) -> Any: + async def __aenter__(self) -> "mongo": """ Asynchronous context manager entry. Establishes a connection when entering the context. Returns: -------- - self : MongoDBDriver + mongo Returns the instance of the driver itself. Raises: @@ -239,11 +343,7 @@ async def __aenter__(self) -> Any: DriverError If an error occurs during connection establishment. """ - try: - await self.connection() - except Exception as err: - error = f"Error on Connection: {err}" - raise DriverError(message=error) from err + await self.connection() return self async def __aexit__(self, exc_type, exc, tb): @@ -266,7 +366,13 @@ async def __aexit__(self, exc_type, exc, tb): """ await self.close() - async def query(self, collection_name: str, filter: dict = None, *args, **kwargs) -> Iterable[Any]: + async def query( + self, + collection_name: str, + filter: Optional[dict] = None, + *args, + **kwargs + ) -> Iterable[Any]: """ Executes a query to retrieve documents from a collection asynchronously. @@ -286,24 +392,24 @@ async def query(self, collection_name: str, filter: dict = None, *args, **kwargs Iterable[Any] An iterable containing the documents returned by the query. """ - error = None - result = None - if not self._database: - self._database = self.use(database=self._database_name) - if not self._database: - raise DriverError("No database selected. Use 'use' method to select it.") - - collection = self._database[collection_name] - cursor = collection.find(filter or {}, *args, **kwargs) - result = [] try: + db = await self._select_database() + collection = db[collection_name] + cursor = collection.find(filter or {}, *args, **kwargs) + result = [] async for document in cursor: result.append(document) + return await self._serializer(result, None) except Exception as err: - error = err - return await self._serializer(result, error) + return await self._serializer(None, err) - async def queryrow(self, collection_name: str, filter: dict = None, *args, **kwargs) -> Optional[dict]: + async def queryrow( + self, + collection_name: str, + filter: Optional[dict] = None, + *args, + **kwargs + ) -> Optional[dict]: """ Executes a query to retrieve a single document from a collection asynchronously. @@ -323,21 +429,21 @@ async def queryrow(self, collection_name: str, filter: dict = None, *args, **kwa Optional[dict] The document returned by the query, or None if no document matches. """ - error = None - result = None - if not self._database: - self._database = self.use(database=self._database_name) - if not self._database: - raise DriverError("No database selected. Use 'use' method to select it.") - - collection = self._database[collection_name] try: + db = await self._select_database() + collection = db[collection_name] result = await collection.find_one(filter or {}, *args, **kwargs) + return await self._serializer(result, None) except Exception as err: - error = err - return await self._serializer(result, error) + return await self._serializer(None, err) - async def fetch(self, collection_name: str, filter: dict = None, *args, **kwargs) -> Iterable[Any]: + async def fetch( + self, + collection_name: str, + filter: Optional[dict] = None, + *args, + **kwargs + ) -> Iterable[Any]: """ Executes a query to retrieve documents from a collection asynchronously. @@ -357,25 +463,28 @@ async def fetch(self, collection_name: str, filter: dict = None, *args, **kwargs Iterable[Any] An iterable containing the documents returned by the query. """ - result = None - if not self._database: - self._database = self.use(database=self._database_name) - if not self._database: - raise DriverError("No database selected. Use 'use' method to select it.") - - collection = self._database[collection_name] - cursor = collection.find(filter or {}, *args, **kwargs) + # return await self.query(collection_name, filter, *args, **kwargs) + error = None result = [] try: + db = await self._select_database() + collection = db[collection_name] + cursor = collection.find(filter or {}, *args, **kwargs) async for document in cursor: result.append(document) + return (result, None) except Exception as err: - raise DriverError(f"Error Getting Data from Mongo {err}") - return result + return (None, err) fetch_all = fetch - async def fetch_one(self, collection_name: str, filter: dict = None, *args, **kwargs) -> Optional[dict]: + async def fetch_one( + self, + collection_name: str, + filter: Optional[dict] = None, + *args, + **kwargs + ) -> Optional[dict]: """ Executes a query to retrieve a single document from a collection asynchronously. @@ -395,99 +504,247 @@ async def fetch_one(self, collection_name: str, filter: dict = None, *args, **kw Optional[dict] The document returned by the query, or None if no document matches. """ - result = None - if not self._database: - self._database = self.use(database=self._database_name) - if not self._database: - raise DriverError("No database selected. Use 'use' method to select it.") - - collection = self._database[collection_name] - try: - result = await collection.find_one(filter or {}, *args, **kwargs) - except Exception as err: - raise DriverError(f"No row to be returned {err}") - return result + return await self.queryrow(collection_name, filter, *args, **kwargs) fetchrow = fetch_one fetchone = fetch_one async def write( self, - data, - table: str = None, - database: str = None, + data: Union[Iterable[dict], pd.DataFrame, pa.Table, Any], + collection: str = None, + database: Optional[str] = None, use_pandas: bool = True, - if_exists: str = "replace", + if_exists: str = "append", **kwargs, ) -> bool: """ Writes data to a collection asynchronously, - with upsert functionality. + supporting 'append' (insert) and 'replace' (upsert) operations. Parameters: ----------- - data : Iterable or pandas DataFrame - The data to be written, which can be any iterable of documents or pandas DataFrame. - table : str, optional + data : Iterable[dict] | pd.DataFrame | pa.Table | Any + The data to be written, which can be any iterable of documents, + pandas DataFrame, Arrow Table, or dataclass instances. + collection : str, optional The name of the collection where the data will be written. database : str, optional The name of the database where the collection resides. use_pandas : bool, optional - If True, uses pandas DataFrame to process the data. Defaults to True. + If True, uses pandas DataFrame or Arrow Table to process the data. Defaults to True. if_exists : str, optional - Specifies what to do if the document already exists ('replace'). - Defaults to 'replace'. + Specifies what to do if the document already exists ('replace' or 'append'). + Defaults to 'append'. kwargs : dict - Additional keyword arguments, e.g., key_field for upsert identification. + Additional keyword arguments, e.g., `key_field` for upsert identification. Returns: -------- bool - Returns True if the write operation is successful, otherwise False. + Returns True if the write operation is successful. + + Raises: + ------- + ValueError + If invalid parameters are provided. + DriverError + If an error occurs during the write operation. """ # Ensure database is selected if database: await self.use(database) - if not self._database: - raise DriverError("No database selected. Use 'use' method to select a database.") + try: + db = await self._select_database() + except DriverError as e: + raise e - if not table: - raise ValueError("No collection (table) specified.") + if not collection: + raise ValueError("No collection specified for write operation.") - collection = self._database[table] + coll = db[collection] - # Process data - if use_pandas and isinstance(data, pd.DataFrame): - # Assume data is a pandas DataFrame - documents = data.to_dict("records") - elif isinstance(data, Iterable): - # Assume data is an iterable of documents - documents = list(data) - else: - raise ValueError("Mongo: Data must be an iterable or a pandas DataFrame.") + # Process data based on type + try: + if use_pandas and isinstance(data, pd.DataFrame): + documents = data.to_dict("records") + elif use_pandas and isinstance(data, pa.Table): + documents = [dict(zip(data.schema.names, row)) for row in data.to_pydict().values()] + elif is_dataclass(data): + documents = [asdict(item) for item in data] if isinstance(data, Sequence) else asdict(data) + elif isinstance(data, Iterable): + documents = list(data) + else: + raise ValueError("Mongo: Data must be an iterable of dicts, pandas DataFrame, Arrow Table, or dataclass instances.") + except Exception as e: + raise DataError(f"Error processing input data: {e}") from e # Get key_field from kwargs or default to '_id' key_field = kwargs.get("key_field", "_id") - - # Build bulk operations operations = [] - if if_exists == "replace": - for doc in documents: - if key_field not in doc: - # If key_field is not in document, generate a unique identifier - doc[key_field] = pymongo.ObjectId() - filter = {key_field: doc[key_field]} - operations.append(pymongo.UpdateOne(filter, {"$set": doc}, upsert=True)) - elif if_exists == "append": - # Insert new documents without checking for existing ones - operations = [pymongo.InsertOne(doc) for doc in documents] - else: - raise ValueError("Invalid value for if_exists: choose 'replace' or 'append'") + try: + if if_exists == "replace": + for doc in documents: + if key_field not in doc: + # If key_field is not in document, generate a unique identifier + doc[key_field] = pymongo.ObjectId() + filter_condition = {key_field: doc[key_field]} + operations.append(pymongo.UpdateOne(filter_condition, {"$set": doc}, upsert=True)) + elif if_exists == "append": + # Insert new documents without checking for existing ones + operations = [pymongo.InsertOne(doc) for doc in documents] + else: + raise ValueError("Invalid value for if_exists: choose 'replace' or 'append'") + except Exception as e: + raise DataError(f"Error preparing bulk operations: {e}") from e # Execute bulk write try: - result = await collection.bulk_write(operations, ordered=False) - return result + if not operations: + raise DataError("No operations to perform during write.") + result = await coll.bulk_write(operations, ordered=False) + self._logger.info(f"Write operation successful: {result.bulk_api_result}") + return True + except Exception as e: + raise DriverError(f"Error during write operation: {e}") from e + + async def truncate_table(self, collection_name: str) -> bool: + """ + Truncates a collection by deleting all documents within it. + + Parameters: + ----------- + collection_name : str + The name of the collection to truncate. + + Returns: + -------- + bool + Returns True if the truncation is successful. + + Raises: + ------- + DriverError + If there is an issue truncating the collection. + """ + try: + db = await self._select_database() + collection = db[collection_name] + result = await collection.delete_many({}) + self._logger.info( + f"Truncated collection '{collection_name}': Deleted {result.deleted_count} documents." + ) + return True + except Exception as e: + raise DriverError( + f"Error truncating collection '{collection_name}': {e}" + ) from e + + async def delete( + self, + collection_name: str, + filter: Optional[dict] = None, + many: bool = False + ) -> int: + """ + Deletes documents from a collection based on a filter condition. + + Parameters: + ----------- + collection_name : str + The name of the collection to delete from. + filter : dict, optional + The filter criteria for deletion. Defaults to None (delete all documents). + many : bool, optional + If True, deletes multiple documents matching the filter. + If False, deletes a single document matching the filter. + Defaults to False. + + Returns: + -------- + int + The number of documents deleted. + + Raises: + ------- + DriverError + If there is an issue during the deletion process. + """ + try: + db = await self._select_database() + collection = db[collection_name] + if many: + result = await collection.delete_many(filter or {}) + self._logger.info( + f"Deleted {result.deleted_count} documents from '{collection_name}' with filter {filter}." + ) + else: + result = await collection.delete_one(filter or {}) + self._logger.info( + f"Deleted {result.deleted_count} document from '{collection_name}' with filter {filter}." + ) + return result.deleted_count + except Exception as e: + raise DriverError( + f"Error deleting documents from '{collection_name}': {e}" + ) from e + + async def drop_collection(self, collection_name: str) -> bool: + """ + Drops a collection from the current database. + + Parameters: + ----------- + collection_name : str + The name of the collection to drop. + + Returns: + -------- + bool + True if the collection was successfully dropped, False otherwise. + + Raises: + ------- + DriverError + If there is an issue dropping the collection. + """ + try: + db = await self._select_database() + result = await db.drop_collection(collection_name) + self._logger.info(f"Dropped collection '{collection_name}': {result}") + return True + except Exception as e: + raise DriverError( + f"Error dropping collection '{collection_name}': {e}" + ) from e + + + async def drop_database(self, database_name: str) -> bool: + """ + Drops a database from the MongoDB server. + + Parameters: + ----------- + database_name : str + The name of the database to drop. + + Returns: + -------- + bool + True if the database was successfully dropped, False otherwise. + + Raises: + ------- + DriverError + If there is an issue dropping the database. + """ + try: + if not self._connection: + raise DriverError("Not connected to MongoDB.") + result = await self._connection.drop_database(database_name) + self._logger.info(f"Dropped database '{database_name}': {result}") + return True except Exception as e: - # Handle exception - raise DriverError(f"Error during write operation: {e}") + raise DriverError( + f"Error dropping database '{database_name}': {e}" + ) from e diff --git a/asyncdb/drivers/mredis.py b/asyncdb/drivers/mredis.py index f957f4f8..c9336754 100644 --- a/asyncdb/drivers/mredis.py +++ b/asyncdb/drivers/mredis.py @@ -19,16 +19,13 @@ class mredis(InitDriver, ConnectionDSNBackend): _provider = "redis" _syntax = "json" _encoding = "utf-8" + _dsn_template: str = "redis://{host}:{port}/{db}" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: - self._dsn = "redis://{host}:{port}/{db}" InitDriver.__init__(self, loop=loop, params=params, **kwargs) ConnectionDSNBackend.__init__(self, dsn=dsn, params=params) - try: - self._encoding = params["encoding"] - del params["encoding"] - except KeyError: - pass + self._encoding = params.pop('encoding', "utf-8") + ### Context magic Methods def __enter__(self): diff --git a/asyncdb/drivers/mysql.py b/asyncdb/drivers/mysql.py index 83c94748..57fb1533 100644 --- a/asyncdb/drivers/mysql.py +++ b/asyncdb/drivers/mysql.py @@ -25,6 +25,7 @@ class mysqlCursor(SQLCursor): class mysqlPool(BasePool): _setup_func: Optional[Callable] = None _init_func: Optional[Callable] = None + _dsn_template: str = "mysql://{user}:{password}@{host}:{port}/{database}" def __init__( self, dsn: str = None, loop: asyncio.AbstractEventLoop = None, params: Optional[dict] = None, **kwargs @@ -32,7 +33,6 @@ def __init__( self._test_query = "SELECT 1" self._max_clients = 300 self._min_size = 10 - self._dsn = "mysql://{user}:{password}@{host}:{port}/{database}" self._init_command = kwargs.pop("init_command", None) self._sql_modes = kwargs.pop("sql_modes", None) super(mysqlPool, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs) @@ -166,9 +166,10 @@ class mysql(SQLDriver, DBCursorBackend): _provider = "mysql" _syntax = "sql" _test_query = "SELECT 1" + _dsn_template: str = "mysql://{user}:{password}@{host}:{port}/{database}" + def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: - self._dsn = "mysql://{user}:{password}@{host}:{port}/{database}" self._prepared = None self._cursor = None self._transaction = None diff --git a/asyncdb/drivers/mysqlclient.py b/asyncdb/drivers/mysqlclient.py index 95f2c6bb..56b4bf2b 100644 --- a/asyncdb/drivers/mysqlclient.py +++ b/asyncdb/drivers/mysqlclient.py @@ -25,6 +25,7 @@ class mysqlCursor(SQLCursor): class mysqlclientPool(BasePool): _setup_func: Optional[Callable] = None _init_func: Optional[Callable] = None + _dsn_template: str = "mysql://{user}:{password}@{host}:{port}/{database}" def __init__( self, dsn: str = None, loop: asyncio.AbstractEventLoop = None, params: Optional[dict] = None, **kwargs @@ -32,7 +33,6 @@ def __init__( self._test_query = "SELECT 1" self._max_clients = 30 self._min_size = 10 - self._dsn = "mysql://{user}:{password}@{host}:{port}/{database}" self._init_command = kwargs.pop("init_command", None) self._sql_modes = kwargs.pop("sql_modes", None) self._executor = ThreadPoolExecutor(max_workers=self._min_size) @@ -188,9 +188,9 @@ class mysqlclient(SQLDriver, DBCursorBackend): _provider = "mysql" _syntax = "sql" _test_query = "SELECT 1" + _dsn_template: str = "mysql://{user}:{password}@{host}:{port}/{database}" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: - self._dsn = "mysql://{user}:{password}@{host}:{port}/{database}" self._prepared = None self._cursor = None self._transaction = None diff --git a/asyncdb/drivers/odbc.py b/asyncdb/drivers/odbc.py index d11138cf..f2dafe67 100644 --- a/asyncdb/drivers/odbc.py +++ b/asyncdb/drivers/odbc.py @@ -32,7 +32,7 @@ async def __aenter__(self) -> "odbcCursor": class odbc(SQLDriver, DBCursorBackend): _provider = "odbc" - _dsn = "Driver={driver};Database={database}" + _dsn_template = "Driver={driver};Database={database}" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: if "host" in params: diff --git a/asyncdb/drivers/pg.py b/asyncdb/drivers/pg.py index 1454f6ad..391a0975 100644 --- a/asyncdb/drivers/pg.py +++ b/asyncdb/drivers/pg.py @@ -91,6 +91,7 @@ class pgPool(BasePool): _setup_func: Optional[Callable] = None _init_func: Optional[Callable] = None + _dsn_template: str = "postgres://{user}:{password}@{host}:{port}/{database}" def __init__( self, dsn: str = None, loop: asyncio.AbstractEventLoop = None, params: Optional[dict] = None, **kwargs @@ -100,7 +101,6 @@ def __init__( self._max_clients = 300 self._min_size = 10 self._server_settings = {} - self._dsn = "postgres://{user}:{password}@{host}:{port}/{database}" super(pgPool, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs) self._custom_record: bool = False custom_record = kwargs.get("custom_record", False) @@ -448,9 +448,9 @@ class pg(SQLDriver, DBCursorBackend, ModelBackend): _provider = "pg" _syntax = "sql" _test_query = "SELECT 1" + _dsn_template: str = "postgres://{user}:{password}@{host}:{port}/{database}" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: - self._dsn = "postgres://{user}:{password}@{host}:{port}/{database}" self.application_name = os.getenv("APP_NAME", "NAV") self._prepared = None self._cursor = None diff --git a/asyncdb/drivers/postgres.py b/asyncdb/drivers/postgres.py index 4446b854..f98a9681 100644 --- a/asyncdb/drivers/postgres.py +++ b/asyncdb/drivers/postgres.py @@ -49,9 +49,9 @@ class postgres(threading.Thread, SQLDriver): _provider = "postgres" _syntax = "sql" _test_query = "SELECT 1" + _dsn_template: str = "postgres://{user}:{password}@{host}:{port}/{database}" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: - self._dsn = "postgres://{user}:{password}@{host}:{port}/{database}" self.application_name = os.getenv("APP_NAME", "NAV") self._is_started = False self._error = None diff --git a/asyncdb/drivers/redis.py b/asyncdb/drivers/redis.py index 10a6b7f7..6e68197b 100644 --- a/asyncdb/drivers/redis.py +++ b/asyncdb/drivers/redis.py @@ -25,10 +25,11 @@ class RedisConfig: class redisPool(BasePool): + _dsn_template = "redis://{host}:{port}/{db}" + def __init__( self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: Union[dict, RedisConfig] = None, **kwargs ) -> None: - self._dsn = "redis://{host}:{port}/{db}" super(redisPool, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs) async def connect(self, **kwargs): @@ -120,9 +121,9 @@ async def execute(self, sentence, *args, **kwargs): class redis(BaseDriver): _provider = "redis" _syntax = "json" + _dsn_template: str = "redis://{host}:{port}/{db}" def __init__(self, dsn: str = None, loop=None, params: dict = None, **kwargs): - self._dsn = "redis://{host}:{port}/{db}" super(redis, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs) if "connection" in kwargs: self._connection = kwargs["connection"] diff --git a/asyncdb/drivers/sa.py b/asyncdb/drivers/sa.py index 9f258956..fcb684f4 100644 --- a/asyncdb/drivers/sa.py +++ b/asyncdb/drivers/sa.py @@ -57,6 +57,7 @@ class sa(SQLDriver, DBCursorBackend): } setup_func: Optional[Callable] = None init_func: Optional[Callable] = None + _dsn_template: str = "{driver}://{user}:{password}@{host}:{port}/{database}" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs): """sa. @@ -68,7 +69,6 @@ def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params params (dict, optional): Connection Parameters. Defaults to None. """ self._session = None - self._dsn = "{driver}://{user}:{password}@{host}:{port}/{database}" self._transaction = None self._driver = "postgresql" self.__cursor__ = None diff --git a/asyncdb/drivers/sqlite.py b/asyncdb/drivers/sqlite.py index 6a182315..e1ee912c 100644 --- a/asyncdb/drivers/sqlite.py +++ b/asyncdb/drivers/sqlite.py @@ -28,7 +28,7 @@ async def __aenter__(self) -> "sqliteCursor": class sqlite(SQLDriver, DBCursorBackend, ModelBackend): _provider: str = "sqlite" _syntax: str = "sql" - _dsn: str = "{database}" + _dsn_template: str = "{database}" def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None: SQLDriver.__init__(self, dsn, loop, params, **kwargs) diff --git a/asyncdb/interfaces/connection.py b/asyncdb/interfaces/connection.py index 741fdb2e..bccdb6c7 100644 --- a/asyncdb/interfaces/connection.py +++ b/asyncdb/interfaces/connection.py @@ -105,12 +105,10 @@ class ConnectionDSNBackend(ABC): """ Interface for Databases with DSN Support. """ + _dsn_template: str def __init__(self, dsn: str = None, params: Optional[dict] = None) -> None: - if dsn: - self._dsn = dsn - else: - self._dsn = self.create_dsn(params) + self._dsn = dsn or self.create_dsn(params) try: self._params = params.copy() except (TypeError, AttributeError, ValueError): @@ -118,7 +116,7 @@ def __init__(self, dsn: str = None, params: Optional[dict] = None) -> None: def create_dsn(self, params: dict): try: - return self._dsn.format_map(SafeDict(**params)) if params else None + return self._dsn_template.format_map(SafeDict(**params)) if params else None except TypeError as err: self._logger.error(err) raise DriverError(f"Error creating DSN connection: {err}") from err diff --git a/asyncdb/models/model.py b/asyncdb/models/model.py index 3c1274be..cd2cb01f 100644 --- a/asyncdb/models/model.py +++ b/asyncdb/models/model.py @@ -14,8 +14,13 @@ from datamodel.abstract import Meta from datamodel.exceptions import ValidationError from datamodel.types import MODEL_TYPES, DB_TYPES - -from asyncdb.exceptions import ConnectionMissing, NoDataFound, DriverError, ModelError, StatementError +from asyncdb.exceptions import ( + ConnectionMissing, + NoDataFound, + DriverError, + ModelError, + StatementError +) from asyncdb.utils.modules import module_exists DB_TYPES[int64] = "bigint" diff --git a/asyncdb/version.py b/asyncdb/version.py index 74a9f1d1..8b667cf1 100644 --- a/asyncdb/version.py +++ b/asyncdb/version.py @@ -3,7 +3,8 @@ __title__ = "asyncdb" __description__ = "Library for Asynchronous data source connections \ Collection of asyncio drivers." -__version__ = "2.9.11" +__version__ = "2.10.1" +__copyright__ = "Copyright (c) 2020-2024 Jesus Lara" __author__ = "Jesus Lara" __author_email__ = "jesuslarag@gmail.com" __license__ = "BSD" diff --git a/examples/test_mongo.py b/examples/test_mongo.py index 40e3ff0d..a8127577 100644 --- a/examples/test_mongo.py +++ b/examples/test_mongo.py @@ -1,33 +1,36 @@ -from asyncdb import AsyncDB import asyncio +from asyncdb import AsyncDB +from asyncdb.drivers.mongo import mongo -loop = asyncio.get_event_loop() -asyncio.set_event_loop(loop) params = { "host": "127.0.0.1", "port": "27017", "username": 'troc_pgdata', - "password": '12345678' + "password": '12345678', + "database": "navigator" } -DRIVER='mongo' +async def test_connect(params): + db = AsyncDB('mongo', params=params) + async with await db.connection() as conn: + print('CONNECTED: ', conn.is_connected() is True) + result, error = await conn.test_connection() + print(result, error) + print(type(result) == list) -async def test_connect(driver, params, event_loop): - db = AsyncDB(driver, params=params, loop=event_loop) - await db.connection() - print('CONNECTED: ', db.is_connected() is True) - result, error = await db.test_connection() - print(result, error) - print(type(result) == list) - await db.close() +async def check_connection(): + async with mongo( + params=params + ) as db_driver: + is_connected = await db_driver.test_connection() + if is_connected: + print("Successfully connected to MongoDB.") + else: + print("Failed to connect to MongoDB.") if __name__ == '__main__': - try: - loop.run_until_complete(test_connect(DRIVER, params, loop)) - except Exception as err: - print(err) - finally: - loop.close() + asyncio.run(test_connect(params)) + asyncio.run(check_connection()) diff --git a/setup.py b/setup.py index 23472048..9ca11c27 100644 --- a/setup.py +++ b/setup.py @@ -225,7 +225,7 @@ def readme(): ], "mongodb": [ "pymongo==4.10.1", - "motor==3.5.1", + "motor==3.6.0", ], "msqlserver": [ "pymssql==2.3.1", @@ -280,7 +280,7 @@ def readme(): "sqlalchemy[asyncio]==2.0.34", "elasticsearch[async]==8.15.1", "pymongo==4.10.1", - "motor==3.5.1", + "motor==3.6.0", "pymssql==2.3.1", "aiocouch==3.0.0", "asyncmy==0.2.9", diff --git a/tests/test_mongo.py b/tests/test_mongo.py index 586bf4a8..0c6d971e 100644 --- a/tests/test_mongo.py +++ b/tests/test_mongo.py @@ -1,27 +1,35 @@ import pytest -from asyncdb import AsyncDB, AsyncPool +import re +from pathlib import Path +import uuid +import pandas as pd +from asyncdb import AsyncDB +from asyncdb.drivers.mongo import mongo +from asyncdb.exceptions import DriverError import asyncio -import asyncpg from io import BytesIO -from pathlib import Path -@pytest.fixture + +@pytest.fixture(scope="session") def event_loop(): - loop = asyncio.get_event_loop() + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) yield loop loop.close() DRIVER='mongo' +MONGODB_VERSION='8.0.4' PARAMS = { "host": "127.0.0.1", "port": "27017", - "username": 'troc_pgdata', - "password": '12345678' + "username": 'troc_mongodata', + "password": '12345678', + "database": "admin" } +# Fixture to establish a connection to MongoDB @pytest.fixture async def conn(event_loop): db = AsyncDB(DRIVER, params=PARAMS, loop=event_loop) @@ -29,6 +37,47 @@ async def conn(event_loop): yield db await db.close() +# Fixture to create a dummy database +@pytest.fixture +async def dummy_db(event_loop): + # Generate a unique database name using UUID + db_name = f"test_db_{uuid.uuid4().hex}" + db_driver = mongo( + params={ + "host": "127.0.0.1", + "port": 27017, + "username": 'troc_mongodata', + "password": '12345678', + } + ) + await db_driver.connection() + assert db_driver.is_connected() is True + await db_driver.use(db_name) + yield db_driver + # Teardown: Drop the dummy database after test + await db_driver.drop_database(db_name) + await db_driver.close() + +# Fixture to create a sample collection within the dummy database +@pytest.fixture +async def sample_collection(dummy_db): + collection_name = "test_collection" + # No need to explicitly create the collection; it will be created upon first insert + yield (dummy_db, collection_name) + # Teardown: Drop the collection after test + await dummy_db.drop_collection(collection_name) + + +# Fixture to create a sample pandas DataFrame with two rows +@pytest.fixture +def sample_dataframe(): + data = { + 'name': ['Alice', 'Bob'], + 'age': [30, 25] + } + df = pd.DataFrame(data) + return df + pytestmark = pytest.mark.asyncio @pytest.mark.parametrize("driver", [ @@ -47,6 +96,149 @@ async def test_connect(driver, event_loop): pytest.assume(db.is_connected() is True) result, error = await db.test_connection() pytest.assume(type(result) == dict) - pytest.assume(result['version'] == '4.4.2') + pytest.assume(result['version'] == MONGODB_VERSION) + + # Check if 'version' matches the pattern 'major.minor.patch' + version_pattern = re.compile(r'^\d+\.\d+\.\d+$') + assert 'version' in result, "Version key not found in result." + assert version_pattern.match(result['version']), f"Version format is incorrect: {result['version']}" + await db.close() pytest.assume(db.is_connected() is False) + + +@pytest.mark.asyncio +async def test_is_connected_success(): + async with mongo( + params=PARAMS + ) as db_driver: + # Initially connected via context manager + assert db_driver.is_connected() == True + +@pytest.mark.asyncio +async def test_is_connected_failure(): + # Connection should fail + try: + async with mongo( + params={ + "host": "localhost", + "port": 27017, + "database": "navigator", + "username": "wrong_user", + "password": "wrong_pass" + } + ) as db_driver: + # Connection should fail + await db_driver.connection() + except DriverError as e: + assert 'Authentication failed' in str(e) + else: + pytest.fail("DriverError was not raised when connecting with wrong credentials.") + +@pytest.mark.asyncio +async def test_is_connected_after_close(): + db_driver = mongo( + params=PARAMS + ) + await db_driver.connection() + assert db_driver.is_connected() == True + await db_driver.close() + assert db_driver.is_connected() == False + + +@pytest.mark.asyncio +async def test_write_dataframe(sample_collection, sample_dataframe): + """ + Test writing a pandas DataFrame to a MongoDB collection. + """ + db_driver, collection_name = sample_collection + + # Use the 'write' method to insert the DataFrame into the collection + write_success = await db_driver.write(data=sample_dataframe, collection=collection_name) + assert write_success is True, "Failed to write DataFrame to MongoDB collection." + + # Verify that the data was written correctly + # Fetch all documents from the collection + documents, error = await db_driver.fetch(collection_name) + assert error is None, f"Error during fetch: {error}" + assert len(documents) == 2, "Number of documents in the collection does not match the DataFrame rows." + + # Convert the fetched documents to a list of dictionaries + fetched_data = documents # Assuming 'documents' is already a list of dicts + + # Strip '_id' from each fetched document + fetched_data_stripped = [{k: v for k, v in doc.items() if k != '_id'} for doc in fetched_data] + + # Convert the DataFrame to a list of dictionaries for comparison + expected_data = sample_dataframe.to_dict('records') + + # Sort both lists for consistent ordering + fetched_data_sorted = sorted(fetched_data_stripped, key=lambda x: x['name']) + expected_data_sorted = sorted(expected_data, key=lambda x: x['name']) + + # Assert that the fetched data matches the expected data + assert fetched_data_sorted == expected_data_sorted, "Data mismatch between DataFrame and MongoDB collection." + +@pytest.mark.asyncio +async def test_drop_collection(sample_collection): + """ + Test dropping a collection from the MongoDB database. + """ + db_driver, collection_name = sample_collection + + # Ensure the collection exists by inserting a document + sample_doc = {"name": "Charlie", "age": 28} + write_success = await db_driver.write(data=[sample_doc], collection=collection_name) + assert write_success is True, "Failed to write sample document to MongoDB collection." + + # Verify that the document exists + documents, error = await db_driver.fetch(collection_name) + assert error is None, f"Error during fetch: {error}" + assert len(documents) == 1, "Sample document was not inserted correctly." + + # Drop the collection + drop_success = await db_driver.drop_collection(collection_name) + assert drop_success is True, "Failed to drop MongoDB collection." + + # Verify that the collection has been dropped + try: + documents, error = await db_driver.fetch(collection_name) + if error: + assert "NamespaceNotFound" in str(error), "Unexpected error when fetching dropped collection." + else: + assert len(documents) == 0, "Collection was not dropped successfully." + except DriverError as e: + # If the collection does not exist, ensure the correct error is raised + assert "NamespaceNotFound" in str(e), "Unexpected error when fetching dropped collection." + + +@pytest.mark.asyncio +async def test_drop_database(dummy_db): + """ + Test dropping a MongoDB database. + """ + db_driver = dummy_db + db_name = db_driver._database_name # Accessing the database name from the driver + + # Ensure the database exists by creating a collection + collection_name = "temp_collection" + sample_doc = {"name": "Diana", "age": 22} + write_success = await db_driver.write(data=[sample_doc], collection=collection_name) + assert write_success is True, "Failed to write sample document to MongoDB collection." + + # Verify that the document exists + documents, error = await db_driver.fetch(collection_name) + assert error is None, f"Error during fetch: {error}" + assert len(documents) == 1, "Sample document was not inserted correctly." + + # Drop the database + drop_success = await db_driver.drop_database(db_name) + assert drop_success is True, "Failed to drop MongoDB database." + + # Verify that the database has been dropped by listing databases + try: + databases = await db_driver._connection.list_database_names() + assert db_name not in databases, "Database was not dropped successfully." + except DriverError as e: + # If the database does not exist, ensure the correct error is raised + assert "DatabaseNotFound" in str(e), "Unexpected error when verifying dropped database." From e014c8b0a400cabde233b496b25a82c38be16c9e Mon Sep 17 00:00:00 2001 From: Jesus Lara Date: Fri, 10 Jan 2025 02:30:17 +0100 Subject: [PATCH 2/2] added support for documentDB --- asyncdb/drivers/mongo.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/asyncdb/drivers/mongo.py b/asyncdb/drivers/mongo.py index 1ee0a5d3..7f2ca671 100644 --- a/asyncdb/drivers/mongo.py +++ b/asyncdb/drivers/mongo.py @@ -2,6 +2,7 @@ from collections.abc import Sequence import asyncio import time +from urllib.parse import urlencode import motor.motor_asyncio import pymongo import pandas as pd @@ -79,6 +80,7 @@ def __init__( self._database = None self._databases: List[str] = [] self._database_name = params.get("database", kwargs.get("database", None)) + self._dbtype: str = params.get("dbtype", kwargs.get("dbtype", "mongodb")) super(mongo, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs) self._dsn = self._construct_dsn(params) @@ -91,19 +93,34 @@ def _construct_dsn(self, params) -> str: host = params.get("host", "localhost") port = params.get("port", 27017) database = self._database_name or "" + authsource = params.get("authsource", database) or 'admin' if username and password: - if database: - return self._dsn_template.format( - username=username, - password=password, - host=host, - port=port, - database=database, - ) + f"?authSource={database}" - else: - return f"mongodb://{username}:{password}@{host}:{port}/?authSource=admin" + base_dsn = self._dsn_template.format( + username=username, + password=password, + host=host, + port=port, + database=database, + ) else: - return f"mongodb://{host}:{port}/{database}" + base_dsn = f"mongodb://{host}:{port}/{database}" + if self._dbtype == 'mongodb': + return base_dsn + f"?authSource={authsource}" + elif self._dbtype == 'atlas': + return f"{base_dsn}?retryWrites=true&w=majority" + elif self._dbtype == 'documentdb': + more_params = params.get('connection_params', {}) + query_params = { + "ssl": "true", + "replicaSet": params.get("replicaSet", "rs0"), + "readPreference": params.get("readPreference", "secondaryPreferred"), + "retryWrites": params.get("retryWrites", "false"), + "tlsCAFile": params.get("tlsCAFile", "global-bundle.pem"), + **more_params + } + query_string = urlencode(query_params) + return f"{base_dsn}?{query_string}" + return base_dsn async def _select_database(self) -> motor.motor_asyncio.AsyncIOMotorDatabase: """