From a3e8ab342aa2bb6bbf55d93e7b77052547fe542f Mon Sep 17 00:00:00 2001 From: Jesus Lara Date: Thu, 24 Oct 2024 16:41:53 +0200 Subject: [PATCH] test copy into table methods for asyncpg --- asyncdb/drivers/pg.py | 98 ++++++++++++++++------------------------ asyncdb/version.py | 2 +- examples/test_asyncdb.py | 47 ++++++++++--------- 3 files changed, 62 insertions(+), 85 deletions(-) diff --git a/asyncdb/drivers/pg.py b/asyncdb/drivers/pg.py index 798ddc71..bb388b37 100644 --- a/asyncdb/drivers/pg.py +++ b/asyncdb/drivers/pg.py @@ -12,6 +12,7 @@ from collections.abc import Callable, Iterable from typing import Any, Optional, Union from dataclasses import is_dataclass +import contextlib from datamodel import BaseModel import asyncpg from asyncpg.exceptions import ( @@ -945,10 +946,12 @@ async def transaction(self): async def commit(self): if self._transaction: await self._transaction.commit() + self._transaction = None async def rollback(self): if self._transaction: await self._transaction.rollback() + self._transaction = None async def cursor(self, sentence: Union[str, any], params: Iterable[Any] = None, **kwargs): # pylint: disable=W0236 if not sentence: @@ -993,6 +996,38 @@ async def __anext__(self): raise StopAsyncIteration ## COPY Functions + @contextlib.asynccontextmanager + async def handle_copy_errors(self, operation_name: str): + try: + yield + except ( + QueryCanceledError, + StatementError, + UniqueViolationError, + ForeignKeyViolationError, + NotNullViolationError + ) as err: + self._logger.warning( + f"AsyncPg {operation_name}: {err}" + ) + raise + except UndefinedTableError as ex: + raise StatementError( + f"Error {operation_name}, table doesn't exist: {ex}" + ) from ex + except UndefinedColumnError as ex: + raise StatementError( + f"Error {operation_name}, Undefined Column: {ex}" + ) from ex + except (InvalidSQLStatementNameError, PostgresSyntaxError) as ex: + raise StatementError( + f"Error {operation_name}: Invalid Statement: {ex}" + ) from ex + except Exception as ex: + raise DriverError( + f"Error {operation_name}: {ex}" + ) from ex + ## type: [ text, csv, binary ] async def copy_from_table(self, table="", schema="public", output=None, file_type="csv", columns=None): """table_copy @@ -1002,7 +1037,7 @@ async def copy_from_table(self, table="", schema="public", output=None, file_typ """ if not self._connection: await self.connection() - try: + async with self.handle_copy_errors("Copy From Table"): result = await self._connection.copy_from_table( table_name=table, schema_name=schema, @@ -1011,23 +1046,6 @@ async def copy_from_table(self, table="", schema="public", output=None, file_typ output=output, ) return result - except ( - QueryCanceledError, - StatementError, - UniqueViolationError, - ForeignKeyViolationError, - NotNullViolationError - ) as err: - self._logger.warning( - f"AsyncPg Copy From Table: {err}" - ) - raise - except UndefinedTableError as ex: - raise StatementError(f"Error on Copy, Table {table }doesn't exists: {ex}") from ex - except (InvalidSQLStatementNameError, PostgresSyntaxError, UndefinedColumnError) as ex: - raise StatementError(f"Error on Copy, Invalid Statement Error: {ex}") from ex - except Exception as ex: - raise DriverError(f"Error on Table Copy: {ex}") from ex async def copy_to_table(self, table="", schema="public", source=None, file_type="csv", columns=None): """copy_to_table @@ -1040,7 +1058,7 @@ async def copy_to_table(self, table="", schema="public", source=None, file_type= if self._transaction: # a transaction exists: await self._transaction.commit() - try: + async with self.handle_copy_errors("Copy To Table"): result = await self._connection.copy_to_table( table_name=table, schema_name=schema, @@ -1049,25 +1067,6 @@ async def copy_to_table(self, table="", schema="public", source=None, file_type= source=source, ) return result - except ( - QueryCanceledError, - StatementError, - UniqueViolationError, - ForeignKeyViolationError, - NotNullViolationError - ) as err: - self._logger.warning( - f"AsyncPg Copy To Table: {err}" - ) - raise - except UndefinedTableError as ex: - raise StatementError( - f"Error on Copy to Table {table } doesn't exists: {ex}") from ex - except (InvalidSQLStatementNameError, PostgresSyntaxError, UndefinedColumnError) as ex: - raise StatementError( - f"Error on Copy, Invalid Statement Error: {ex}") from ex - except Exception as ex: - raise DriverError(f"Error on Copy to Table {ex}") from ex async def copy_into_table(self, table="", schema="public", source=None, columns=None): """copy_into_table @@ -1080,32 +1079,11 @@ async def copy_into_table(self, table="", schema="public", source=None, columns= if self._transaction: # a transaction exists: await self._transaction.commit() - try: + async with self.handle_copy_errors("Copy Into Table"): result = await self._connection.copy_records_to_table( table_name=table, schema_name=schema, columns=columns, records=source ) return result - except ( - QueryCanceledError, - StatementError, - UniqueViolationError, - ForeignKeyViolationError, - NotNullViolationError - ) as err: - self._logger.warning( - f"AsyncPg Copy Into Table: {err}" - ) - raise - except UndefinedTableError as ex: - raise StatementError(f"Error on Copy to Table {table } doesn't exists: {ex}") from ex - except (InvalidSQLStatementNameError, PostgresSyntaxError, UndefinedColumnError) as ex: - raise StatementError(f"Error on Copy, Invalid Statement Error: {ex}") from ex - except InterfaceError as ex: - raise DriverError(f"Error on Copy into Table Function: {ex}") from ex - except (RuntimeError, PostgresError) as ex: - raise DriverError(f"Postgres Error on Copy into Table: {ex}") from ex - except Exception as ex: - raise DriverError(f"Error on Copy into Table: {ex}") from ex ## Model Logic: async def column_info(self, tablename: str, schema: str = None): diff --git a/asyncdb/version.py b/asyncdb/version.py index a30a0164..512fe956 100644 --- a/asyncdb/version.py +++ b/asyncdb/version.py @@ -3,7 +3,7 @@ __title__ = "asyncdb" __description__ = "Library for Asynchronous data source connections \ Collection of asyncio drivers." -__version__ = "2.9.5" +__version__ = "2.9.6" __author__ = "Jesus Lara" __author_email__ = "jesuslarag@gmail.com" __license__ = "BSD" diff --git a/examples/test_asyncdb.py b/examples/test_asyncdb.py index b672fe1b..70424ffe 100644 --- a/examples/test_asyncdb.py +++ b/examples/test_asyncdb.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- from asyncdb import AsyncDB, AsyncPool -from asyncdb.meta import asyncORM from asyncdb.exceptions import NoDataFound, ProviderError, StatementError """ @@ -26,7 +25,7 @@ loop = asyncio.get_event_loop() asyncio.set_event_loop(loop) -asyncpg_url = "postgres://troc_pgdata:12345678@127.0.0.1:5432/navigator_dev" +asyncpg_url = "postgres://troc_pgdata:12345678@127.0.0.1:5432/navigator" pool = AsyncPool("pg", dsn=asyncpg_url, loop=loop) loop.run_until_complete(pool.connect()) @@ -44,7 +43,7 @@ def adb(): if pool.is_connected(): #db = asyncio.get_running_loop().run_until_complete(dbpool.acquire()) db = loop.run_until_complete(pool.acquire()) - return asyncORM(db=db) + return db def sharing_token(token): db = adb() @@ -97,26 +96,28 @@ async def connect(c): result, error = await conn.execute("SET TIMEZONE TO 'America/New_York'") await t.commit() # table copy - await c.copy_from_table( - table="stores", - schema="walmart", - columns=["store_id", "store_name"], - output="stores.csv", - ) + async with await c.transaction() as t: + await t.copy_from_table( + table="stores", + schema="walmart", + columns=["store_id", "store_name"], + output="stores.csv", + ) # copy from file to table # TODO: repair error io.UnsupportedOperation: read - # await c.copy_to_table(table = 'stores', schema = 'test', columns = [ 'store_id', 'store_name'], source = '/home/jesuslara/proyectos/navigator-next/stores.csv') - # copy from asyncpg records - # try: - # await c.copy_into_table( - # table="stores", - # schema="test", - # columns=["store_id", "store_name"], - # source=stores, - # ) - # except (StatementError, ProviderError) as err: - # print(str(err)) - # return False + async with await t.transaction() as t: + await t.copy_to_table(table = 'stores', schema = 'test', columns = [ 'store_id', 'store_name'], source = '/home/jesuslara/proyectos/navigator-next/stores.csv') + # copy from asyncpg records + try: + await c.copy_into_table( + table="stores", + schema="test", + columns=["store_id", "store_name"], + source=stores, + ) + except (StatementError, ProviderError) as err: + print(str(err)) + return False async def prepared(p): @@ -128,11 +129,9 @@ async def prepared(p): if __name__ == '__main__': try: - a = sharing_token('67C1BEE8DDC0BB873930D04FAF16B338F8CB09490571F8901E534937D4EFA8EE33230C435BDA93B7C7CEBA67858C4F70321A0D92201947F13278F495F92DDC0BE5FDFCF0684704C78A3E7BA5133ACADBE2E238F25D568AEC4170EB7A0BE819CE8F758B890855E5445EB22BE52439FA377D00C9E4225BC6DAEDD2DAC084446E7F697BF1CEC129DFB84FA129B7B8881C66EEFD91A0869DAE5D71FD5055FCFF75') - print(a.columns()) # # test: first with db connected: e = AsyncDB("pg", dsn=asyncpg_url, loop=loop) loop.run_until_complete(connect(e)) # loop.run_until_complete(prepared(e)) finally: - pool.terminate() + loop.close()