From 47cce3a94e5a3987b65c5f3b3d93dd074463e26d Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 23 Oct 2024 11:12:42 +0300 Subject: [PATCH] style fixes --- pyproject.toml | 7 +++- tests/conftest.py | 18 +++++++-- tests/test_connection.py | 30 ++++++++++----- tests/test_cursor.py | 22 ++++++++--- ydb_dbapi/__init__.py | 6 +-- ydb_dbapi/{connection.py => connections.py} | 42 +++++++++++++-------- ydb_dbapi/cursors.py | 23 +++++++---- ydb_dbapi/errors.py | 4 -- ydb_dbapi/utils.py | 6 ++- 9 files changed, 104 insertions(+), 54 deletions(-) rename ydb_dbapi/{connection.py => connections.py} (90%) diff --git a/pyproject.toml b/pyproject.toml index 71fdf5a..194e8fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ydb-dbapi" -version = "0.1.0" +version = "0.0.1" description = "" authors = ["Oleg Ovcharuk "] readme = "README.md" @@ -58,6 +58,8 @@ ignore = [ # Ignores below could be deleted "EM101", # Allow to use string literals in exceptions "TRY003", # Allow specifying long messages outside the exception class + "SLF001", # Allow access private member, + "PGH003", # Allow not to specify rule codes ] select = ["ALL"] @@ -72,7 +74,8 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" force-single-line = true [tool.ruff.lint.per-file-ignores] -"**/test_*.py" = ["S", "SLF", "ANN201", "ARG", "PLR2004"] +"**/test_*.py" = ["S", "SLF", "ANN201", "ARG", "PLR2004", "PT012"] +"conftest.py" = ["ARG001"] "__init__.py" = ["F401", "F403"] [tool.pytest.ini_options] diff --git a/tests/conftest.py b/tests/conftest.py index 07e8dab..a863061 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ from __future__ import annotations +from asyncio import AbstractEventLoop +from collections.abc import AsyncGenerator from collections.abc import Generator from typing import Any from typing import Callable @@ -9,6 +11,7 @@ from testcontainers.core.generic import DbContainer from testcontainers.core.generic import wait_container_is_ready from testcontainers.core.utils import setup_logger +from typing_extensions import Self logger = setup_logger(__name__) @@ -33,7 +36,7 @@ def __init__( self._name = name self._database_name = "local" - def start(self): + def start(self) -> Self: self._maybe_stop_old_container() super().start() return self @@ -115,7 +118,9 @@ def connection_kwargs(ydb_container: YDBContainer) -> dict: @pytest.fixture -async def driver(ydb_container, event_loop): +async def driver( + ydb_container: YDBContainer, event_loop: AbstractEventLoop +) -> AsyncGenerator[ydb.aio.Driver]: driver = ydb.aio.Driver( connection_string=ydb_container.get_connection_string() ) @@ -128,7 +133,9 @@ async def driver(ydb_container, event_loop): @pytest.fixture -async def session_pool(driver: ydb.aio.Driver): +async def session_pool( + driver: ydb.aio.Driver, +) -> AsyncGenerator[ydb.aio.QuerySessionPool]: session_pool = ydb.aio.QuerySessionPool(driver) async with session_pool: await session_pool.execute_with_retries( @@ -146,8 +153,11 @@ async def session_pool(driver: ydb.aio.Driver): yield session_pool + @pytest.fixture -async def session(session_pool: ydb.aio.QuerySessionPool): +async def session( + session_pool: ydb.aio.QuerySessionPool, +) -> AsyncGenerator[ydb.aio.QuerySession]: session = await session_pool.acquire() yield session diff --git a/tests/test_connection.py b/tests/test_connection.py index 4a683d6..2300efd 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator from contextlib import suppress import pytest @@ -37,7 +40,7 @@ async def _test_isolation_level_read_only( await connection.rollback() async with connection.cursor() as cursor: - cursor.execute("DROP TABLE foo") + await cursor.execute("DROP TABLE foo") async def _test_connection(self, connection: dbapi.Connection) -> None: await connection.commit() @@ -66,7 +69,9 @@ async def _test_connection(self, connection: dbapi.Connection) -> None: await cur.execute("DROP TABLE foo") await cur.close() - async def _test_cursor_raw_query(self, connection: dbapi.Connection) -> None: + async def _test_cursor_raw_query( + self, connection: dbapi.Connection + ) -> None: cur = connection.cursor() assert cur @@ -107,7 +112,10 @@ async def _test_cursor_raw_query(self, connection: dbapi.Connection) -> None: async def _test_errors(self, connection: dbapi.Connection) -> None: with pytest.raises(dbapi.InterfaceError): - await dbapi.connect("localhost:2136", database="/local666") + await dbapi.connect( + "localhost:2136", # type: ignore + database="/local666", # type: ignore + ) cur = connection.cursor() @@ -142,8 +150,10 @@ async def _test_errors(self, connection: dbapi.Connection) -> None: class TestAsyncConnection(BaseDBApiTestSuit): @pytest_asyncio.fixture - async def connection(self, connection_kwargs): - conn = await dbapi.connect(**connection_kwargs) + async def connection( + self, connection_kwargs: dict + ) -> AsyncGenerator[dbapi.Connection]: + conn = await dbapi.connect(**connection_kwargs) # ignore: typing try: yield conn finally: @@ -166,19 +176,21 @@ async def test_isolation_level_read_only( isolation_level: str, read_only: bool, connection: dbapi.Connection, - ): + ) -> None: await self._test_isolation_level_read_only( connection, isolation_level, read_only ) @pytest.mark.asyncio - async def test_connection(self, connection: dbapi.Connection): + async def test_connection(self, connection: dbapi.Connection) -> None: await self._test_connection(connection) @pytest.mark.asyncio - async def test_cursor_raw_query(self, connection: dbapi.Connection): + async def test_cursor_raw_query( + self, connection: dbapi.Connection + ) -> None: await self._test_cursor_raw_query(connection) @pytest.mark.asyncio - async def test_errors(self, connection: dbapi.Connection): + async def test_errors(self, connection: dbapi.Connection) -> None: await self._test_errors(connection) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 2ddd8b2..481d782 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,9 +1,10 @@ import pytest import ydb_dbapi +from ydb.aio import QuerySession @pytest.mark.asyncio -async def test_cursor_ddl(session): +async def test_cursor_ddl(session: QuerySession) -> None: cursor = ydb_dbapi.Cursor(session=session) yql = """ @@ -27,7 +28,7 @@ async def test_cursor_ddl(session): @pytest.mark.asyncio -async def test_cursor_dml(session): +async def test_cursor_dml(session: QuerySession) -> None: cursor = ydb_dbapi.Cursor(session=session) yql_text = """ INSERT INTO table (id, val) VALUES @@ -48,12 +49,13 @@ async def test_cursor_dml(session): await cursor.execute(query=yql_text) res = await cursor.fetchone() + assert res is not None assert len(res) == 1 assert res[0] == 3 @pytest.mark.asyncio -async def test_cursor_fetch_one(session): +async def test_cursor_fetch_one(session: QuerySession) -> None: cursor = ydb_dbapi.Cursor(session=session) yql_text = """ INSERT INTO table (id, val) VALUES @@ -73,16 +75,18 @@ async def test_cursor_fetch_one(session): await cursor.execute(query=yql_text) res = await cursor.fetchone() + assert res is not None assert res[0] == 1 res = await cursor.fetchone() + assert res is not None assert res[0] == 2 assert await cursor.fetchone() is None @pytest.mark.asyncio -async def test_cursor_fetch_many(session): +async def test_cursor_fetch_many(session: QuerySession) -> None: cursor = ydb_dbapi.Cursor(session=session) yql_text = """ INSERT INTO table (id, val) VALUES @@ -104,15 +108,18 @@ async def test_cursor_fetch_many(session): await cursor.execute(query=yql_text) res = await cursor.fetchmany() + assert res is not None assert len(res) == 1 assert res[0][0] == 1 res = await cursor.fetchmany(size=2) + assert res is not None assert len(res) == 2 assert res[0][0] == 2 assert res[1][0] == 3 res = await cursor.fetchmany(size=2) + assert res is not None assert len(res) == 1 assert res[0][0] == 4 @@ -120,7 +127,7 @@ async def test_cursor_fetch_many(session): @pytest.mark.asyncio -async def test_cursor_fetch_all(session): +async def test_cursor_fetch_all(session: QuerySession) -> None: cursor = ydb_dbapi.Cursor(session=session) yql_text = """ INSERT INTO table (id, val) VALUES @@ -143,6 +150,7 @@ async def test_cursor_fetch_all(session): assert cursor.rowcount == 3 res = await cursor.fetchall() + assert res is not None assert len(res) == 3 assert res[0][0] == 1 assert res[1][0] == 2 @@ -152,13 +160,14 @@ async def test_cursor_fetch_all(session): @pytest.mark.asyncio -async def test_cursor_next_set(session): +async def test_cursor_next_set(session: QuerySession) -> None: cursor = ydb_dbapi.Cursor(session=session) yql_text = """SELECT 1 as val; SELECT 2 as val;""" await cursor.execute(query=yql_text) res = await cursor.fetchall() + assert res is not None assert len(res) == 1 assert res[0][0] == 1 @@ -166,6 +175,7 @@ async def test_cursor_next_set(session): assert nextset res = await cursor.fetchall() + assert res is not None assert len(res) == 1 assert res[0][0] == 2 diff --git a/ydb_dbapi/__init__.py b/ydb_dbapi/__init__.py index 93d9e23..475d377 100644 --- a/ydb_dbapi/__init__.py +++ b/ydb_dbapi/__init__.py @@ -1,5 +1,5 @@ -from .connection import Connection -from .connection import IsolationLevel -from .connection import connect +from .connections import Connection +from .connections import IsolationLevel +from .connections import connect from .cursors import Cursor from .errors import * diff --git a/ydb_dbapi/connection.py b/ydb_dbapi/connections.py similarity index 90% rename from ydb_dbapi/connection.py rename to ydb_dbapi/connections.py index bce7e03..4a6c084 100644 --- a/ydb_dbapi/connection.py +++ b/ydb_dbapi/connections.py @@ -50,9 +50,7 @@ def __init__( "ydb_session_pool" in self.conn_kwargs ): # Use session pool managed manually self._shared_session_pool = True - self._session_pool = self.conn_kwargs.pop( - "ydb_session_pool" - ) + self._session_pool = self.conn_kwargs.pop("ydb_session_pool") self._driver = self._session_pool._driver else: self._shared_session_pool = False @@ -127,13 +125,24 @@ class Connection(BaseYDBConnection): _ydb_driver_class = ydb.aio.Driver _ydb_session_pool_class = ydb.aio.QuerySessionPool - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__( + self, + host: str = "", + port: str = "", + database: str = "", + **conn_kwargs: Any, + ) -> None: + super().__init__( + host, + port, + database, + **conn_kwargs, + ) self._session: ydb.aio.QuerySession | None = None - self._tx_context: ydb.QueryTxContext | None = None + self._tx_context: ydb.aio.QueryTxContext | None = None - async def _wait(self, timeout: int = 5) -> None: + async def wait_ready(self, timeout: int = 5) -> None: try: await self._driver.wait(timeout, fail_fast=True) except ydb.Error as e: @@ -144,13 +153,13 @@ async def _wait(self, timeout: int = 5) -> None: "Failed to connect to YDB, details " f"{self._driver.discovery_debug_details()}" ) - raise InterfaceError( - msg - ) from e + raise InterfaceError(msg) from e self._session = await self._session_pool.acquire() - def cursor(self): + def cursor(self) -> Cursor: + if self._session is None: + raise RuntimeError("Connection is not ready, use wait_ready.") if self._current_cursor and not self._current_cursor._closed: raise RuntimeError( "Unable to create new Cursor before closing existing one." @@ -218,12 +227,13 @@ async def callee() -> None: await self._driver.scheme_client.describe_path(table_path) await retry_operation_async(callee) - return True except ydb.SchemeError: return False + else: + return True async def _get_table_names(self, abs_dir_path: str) -> list[str]: - async def callee(): + async def callee() -> ydb.Directory: return await self._driver.scheme_client.list_directory( abs_dir_path ) @@ -239,7 +249,7 @@ async def callee(): return result -async def connect(*args, **kwargs) -> Connection: - conn = Connection(*args, **kwargs) - await conn._wait() +async def connect(*args: tuple, **kwargs: dict) -> Connection: + conn = Connection(*args, **kwargs) # type: ignore + await conn.wait_ready() return conn diff --git a/ydb_dbapi/cursors.py b/ydb_dbapi/cursors.py index 51ff139..7d87c03 100644 --- a/ydb_dbapi/cursors.py +++ b/ydb_dbapi/cursors.py @@ -2,11 +2,13 @@ import itertools from collections.abc import AsyncIterator +from collections.abc import Generator from collections.abc import Iterator from typing import Any from typing import Union import ydb +from typing_extensions import Self from .errors import DatabaseError from .errors import Error @@ -53,20 +55,18 @@ def __init__( self._state = CursorStatus.ready @property - def description(self): + def description(self) -> list[tuple] | None: return self._description @property - def rowcount(self): + def rowcount(self) -> int: return self._rows_count @handle_ydb_errors async def _execute_generic_query( self, query: str, parameters: ParametersType | None = None ) -> AsyncIterator[ydb.convert.ResultSet]: - return await self._session.execute( - query=query, parameters=parameters - ) + return await self._session.execute(query=query, parameters=parameters) @handle_ydb_errors async def _execute_transactional_query( @@ -126,7 +126,9 @@ def _update_description(self, result_set: ydb.convert.ResultSet) -> None: for col in result_set.columns ] - def _rows_iterable(self, result_set): + def _rows_iterable( + self, result_set: ydb.convert.ResultSet + ) -> Generator[tuple]: try: for row in result_set.rows: # returns tuple to be compatible with SqlAlchemy and because @@ -212,8 +214,13 @@ def _check_cursor_closed(self) -> None: "Could not perform operation: Cursor is closed." ) - async def __aenter__(self) -> Cursor: + async def __aenter__(self) -> Self: return self - async def __aexit__(self, exc_type, exc, tb) -> None: + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: object, + ) -> None: await self.close() diff --git a/ydb_dbapi/errors.py b/ydb_dbapi/errors.py index e5463b6..011e16c 100644 --- a/ydb_dbapi/errors.py +++ b/ydb_dbapi/errors.py @@ -4,10 +4,6 @@ from google.protobuf.message import Message -class Warning(Exception): - pass - - class Error(Exception): def __init__( self, diff --git a/ydb_dbapi/utils.py b/ydb_dbapi/utils.py index 55ea9cb..d11f418 100644 --- a/ydb_dbapi/utils.py +++ b/ydb_dbapi/utils.py @@ -1,5 +1,7 @@ import functools from enum import Enum +from typing import Any +from typing import Callable import ydb @@ -12,9 +14,9 @@ from .errors import ProgrammingError -def handle_ydb_errors(func): +def handle_ydb_errors(func: Callable) -> Callable: @functools.wraps(func) - async def wrapper(*args, **kwargs): + async def wrapper(*args: tuple, **kwargs: dict) -> Any: try: return await func(*args, **kwargs) except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: