diff --git a/test/test_core.py b/test/test_core.py index 1dacd09..fa4e326 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -6,7 +6,7 @@ import sqlalchemy as sa import ydb from sqlalchemy import Table, Column, Integer, Unicode -from sqlalchemy.testing.fixtures import TestBase, TablesTest +from sqlalchemy.testing.fixtures import TestBase, TablesTest, config from ydb._grpc.v4.protos import ydb_common_pb2 from ydb_sqlalchemy import dbapi, IsolationLevel @@ -220,8 +220,9 @@ def _create_table_and_get_desc(connection, metadata, **kwargs): ) table.create(connection) - session: ydb.Session = connection.connection.driver_connection.session + session: ydb.Session = connection.connection.driver_connection.session_pool.acquire() table_description = session.describe_table("/local/" + table.name) + connection.connection.driver_connection.session_pool.release(session) return table_description @pytest.mark.parametrize( @@ -419,11 +420,11 @@ def test_interactive_transaction( connection_no_trans.execution_options(isolation_level=isolation_level) with connection_no_trans.begin(): - tx_id = dbapi_connection.transaction.tx_id + tx_id = dbapi_connection.tx_context.tx_id assert tx_id is not None cursor1 = connection_no_trans.execute(sa.select(table)) cursor2 = connection_no_trans.execute(sa.select(table)) - assert dbapi_connection.transaction.tx_id == tx_id + assert dbapi_connection.tx_context.tx_id == tx_id assert set(cursor1.fetchall()) == {(5,), (6,)} assert set(cursor2.fetchall()) == {(5,), (6,)} @@ -448,10 +449,10 @@ def test_not_interactive_transaction( connection_no_trans.execution_options(isolation_level=isolation_level) with connection_no_trans.begin(): - assert dbapi_connection.transaction is None + assert dbapi_connection.tx_context is None cursor1 = connection_no_trans.execute(sa.select(table)) cursor2 = connection_no_trans.execute(sa.select(table)) - assert dbapi_connection.transaction is None + assert dbapi_connection.tx_context is None assert set(cursor1.fetchall()) == {(7,), (8,)} assert set(cursor2.fetchall()) == {(7,), (8,)} @@ -482,7 +483,59 @@ def test_connection_set(self, connection_no_trans: sa.Connection): assert dbapi_connection.tx_mode.name == ydb_isolation_settings[0] assert dbapi_connection.interactive_transaction is ydb_isolation_settings[1] if dbapi_connection.interactive_transaction: - assert dbapi_connection.transaction is not None - assert dbapi_connection.transaction.tx_id is not None + assert dbapi_connection.tx_context is not None + assert dbapi_connection.tx_context.tx_id is not None else: - assert dbapi_connection.transaction is None + assert dbapi_connection.tx_context is None + + +class TestEngine(TestBase): + @pytest.fixture(scope="module") + def ydb_driver(self): + url = config.db_url + driver = ydb.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database) + try: + driver.wait(timeout=5, fail_fast=True) + yield driver + finally: + driver.stop() + + driver.stop() + + @pytest.fixture(scope="module") + def ydb_pool(self, ydb_driver): + session_pool = ydb.SessionPool(ydb_driver, size=5, workers_threads_count=1) + + yield session_pool + + session_pool.stop() + + def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): + engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool}) + engine2 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool}) + + with engine1.connect() as conn1, engine2.connect() as conn2: + dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection + dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection + + assert dbapi_conn1.session_pool is dbapi_conn2.session_pool + assert dbapi_conn1.driver is dbapi_conn2.driver + + engine1.dispose() + engine2.dispose() + assert not ydb_driver._stopped + + def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): + engine1 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool}) + engine2 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool}) + + with engine1.connect() as conn1, engine2.connect() as conn2: + dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection + dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection + + assert dbapi_conn1.session_pool is dbapi_conn2.session_pool + assert dbapi_conn1.driver is dbapi_conn2.driver + + engine1.dispose() + engine2.dispose() + assert not ydb_driver._stopped diff --git a/test_dbapi/conftest.py b/test_dbapi/conftest.py index 92f6610..7a9f5a3 100644 --- a/test_dbapi/conftest.py +++ b/test_dbapi/conftest.py @@ -1,9 +1,10 @@ import pytest + import ydb_sqlalchemy.dbapi as dbapi @pytest.fixture(scope="module") def connection(): - conn = dbapi.connect("localhost:2136", database="/local") + conn = dbapi.connect(host="localhost", port="2136", database="/local") yield conn conn.close() diff --git a/ydb_sqlalchemy/dbapi/connection.py b/ydb_sqlalchemy/dbapi/connection.py index 8d6ed1d..43e6273 100644 --- a/ydb_sqlalchemy/dbapi/connection.py +++ b/ydb_sqlalchemy/dbapi/connection.py @@ -1,5 +1,5 @@ import posixpath -from typing import Optional, NamedTuple +from typing import Optional, NamedTuple, Any import ydb @@ -17,23 +17,38 @@ class IsolationLevel: class Connection: - def __init__(self, endpoint=None, host=None, port=None, database=None, **conn_kwargs): - self.endpoint = endpoint or f"grpc://{host}:{port}" + def __init__( + self, + host: str = "", + port: str = "", + database: str = "", + **conn_kwargs: Any, + ): + self.endpoint = f"grpc://{host}:{port}" self.database = database - self.table_client_settings = self._get_table_client_settings() - self.driver = self._create_driver(**conn_kwargs) - self.session = self._create_session() + self.conn_kwargs = conn_kwargs + + if "ydb_session_pool" in self.conn_kwargs: # Use session pool managed manually + self._shared_session_pool = True + self.session_pool: ydb.SessionPool = self.conn_kwargs.pop("ydb_session_pool") + self.driver = self.session_pool._pool_impl._driver + self.driver.table_client = ydb.TableClient(self.driver, self._get_table_client_settings()) + else: + self._shared_session_pool = False + self.driver = self._create_driver() + self.session_pool = ydb.SessionPool(self.driver, size=5, workers_threads_count=1) + self.interactive_transaction: bool = False # AUTOCOMMIT self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite() - self.transaction: Optional[ydb.TxContext] = None + self.tx_context: Optional[ydb.TxContext] = None def cursor(self): - return Cursor(self, transaction=self.transaction) + return Cursor(self.session_pool, self.tx_context) def describe(self, table_path): full_path = posixpath.join(self.database, table_path) try: - return ydb.retry_operation_sync(lambda: self.session.describe_table(full_path)) + return self.session_pool.retry_operation_sync(lambda session: session.describe_table(full_path)) except ydb.issues.SchemeError as e: raise ProgrammingError(e.message, e.issues, e.status) from e except ydb.Error as e: @@ -64,7 +79,7 @@ class IsolationSettings(NamedTuple): IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly(), interactive=True), } ydb_isolation_settings = ydb_isolation_settings_map[isolation_level] - if self.transaction and self.transaction.tx_id: + if self.tx_context and self.tx_context.tx_id: raise InternalError("Failed to set transaction mode: transaction is already began") self.tx_mode = ydb_isolation_settings.ydb_mode self.interactive_transaction = ydb_isolation_settings.interactive @@ -88,27 +103,31 @@ def get_isolation_level(self) -> str: raise NotSupportedError(f"{self.tx_mode.name} is not supported") def begin(self): - if not self.session.initialized(): - raise InternalError("Failed to begin transaction: session closed") - self.transaction = None + self.tx_context = None if self.interactive_transaction: - self.transaction = self.session.transaction(self.tx_mode) - self.transaction.begin() + session = self.session_pool.acquire(blocking=True) + self.tx_context = session.transaction(self.tx_mode) + self.tx_context.begin() def commit(self): - if self.transaction and self.transaction.tx_id: - self.transaction.commit() + if self.tx_context and self.tx_context.tx_id: + self.tx_context.commit() + self.session_pool.release(self.tx_context.session) + self.tx_context = None def rollback(self): - if self.transaction and self.transaction.tx_id: - self.transaction.rollback() + if self.tx_context and self.tx_context.tx_id: + self.tx_context.rollback() + self.session_pool.release(self.tx_context.session) + self.tx_context = None def close(self): - self._delete_session() - self._stop_driver() + self.rollback() + if not self._shared_session_pool: + self.session_pool.stop() + self._stop_driver() - @staticmethod - def _get_table_client_settings() -> ydb.TableClientSettings: + def _get_table_client_settings(self) -> ydb.TableClientSettings: return ( ydb.TableClientSettings() .with_native_date_in_result_sets(True) @@ -118,13 +137,11 @@ def _get_table_client_settings() -> ydb.TableClientSettings: .with_native_json_in_result_sets(True) ) - def _create_driver(self, **conn_kwargs): - # TODO: add cache for initialized drivers/pools? + def _create_driver(self): driver_config = ydb.DriverConfig( endpoint=self.endpoint, database=self.database, - table_client_settings=self.table_client_settings, - **conn_kwargs, + table_client_settings=self._get_table_client_settings(), ) driver = ydb.Driver(driver_config) try: @@ -138,13 +155,3 @@ def _create_driver(self, **conn_kwargs): def _stop_driver(self): self.driver.stop() - - def _create_session(self) -> ydb.BaseSession: - session = ydb.Session(self.driver, self.table_client_settings) - session.create() - return session - - def _delete_session(self): - if self.session.initialized(): - self.rollback() - self.session.delete() diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index 6defdbe..4ae9565 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -1,10 +1,10 @@ import dataclasses import itertools import logging - -from typing import Any, Mapping, Optional, Sequence, Union, Dict +from typing import Any, Mapping, Optional, Sequence, Union, Dict, Callable import ydb + from .errors import ( InternalError, IntegrityError, @@ -15,7 +15,6 @@ NotSupportedError, ) - logger = logging.getLogger(__name__) @@ -33,10 +32,13 @@ class YdbQuery: class Cursor(object): - def __init__(self, connection, transaction: Optional[ydb.BaseTxContext] = None): - self.connection = connection - self.session: ydb.Session = self.connection.session - self.transaction = transaction + def __init__( + self, + session_pool: ydb.SessionPool, + tx_context: Optional[ydb.BaseTxContext] = None, + ): + self.session_pool = session_pool + self.tx_context = tx_context self.description = None self.arraysize = 1 self.rows = None @@ -50,7 +52,15 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = query = ydb.DataQuery(operation.yql_text, operation.parameters_types) is_ddl = operation.is_ddl - chunks = self._execute(query, parameters, is_ddl) + logger.info("execute sql: %s, params: %s", query, parameters) + if is_ddl: + chunks = self.session_pool.retry_operation_sync(self._execute_ddl, None, query) + else: + if self.tx_context: + chunks = self._execute_dml(self.tx_context.session, query, parameters, self.tx_context) + else: + chunks = self.session_pool.retry_operation_sync(self._execute_dml, None, query, parameters) + rows = self._rows_iterable(chunks) # Prefetch the description: try: @@ -64,23 +74,31 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = self.rows = rows - def _execute(self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]], is_ddl: bool): - self.description = None - logger.info("execute sql: %s, params: %s", query, parameters) + @classmethod + def _execute_dml( + cls, + session: ydb.Session, + query: ydb.DataQuery, + parameters: Optional[Mapping[str, Any]] = None, + tx_context: Optional[ydb.BaseTxContext] = None, + ) -> ydb.convert.ResultSets: + prepared_query = query + if isinstance(query, str) and parameters: + prepared_query = session.prepare(query) + + if tx_context: + return cls._handle_ydb_errors(tx_context.execute, prepared_query, parameters) + + return cls._handle_ydb_errors(session.transaction().execute, prepared_query, parameters, commit_tx=True) + + @classmethod + def _execute_ddl(cls, session: ydb.Session, query: str) -> ydb.convert.ResultSets: + return cls._handle_ydb_errors(session.execute_scheme, query) + + @staticmethod + def _handle_ydb_errors(callee: Callable, *args, **kwargs) -> Any: try: - if is_ddl: - return ydb.retry_operation_sync(lambda: self.session.execute_scheme(query)) - - prepared_query = query - if isinstance(query, str) and parameters: - prepared_query = self.session.prepare(query) - - if not self.transaction: - return ydb.retry_operation_sync( - lambda: self.session.transaction().execute(prepared_query, parameters, commit_tx=True) - ) - else: - return self.transaction.execute(prepared_query, parameters) + return callee(*args, **kwargs) except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: raise IntegrityError(e.message, e.issues, e.status) from e except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: @@ -108,7 +126,7 @@ def _execute(self, query: Union[ydb.DataQuery, str], parameters: Optional[Mappin except ydb.Error as e: raise DatabaseError(e.message, e.issues, e.status) from e - def _rows_iterable(self, chunks_iterable): + def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets): try: for chunk in chunks_iterable: self.description = [