Skip to content

Commit

Permalink
Use small session pool in each connection + shared support
Browse files Browse the repository at this point in the history
  • Loading branch information
tretyak-rd committed Jan 15, 2024
1 parent a0294fc commit a85232c
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 72 deletions.
71 changes: 62 additions & 9 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sqlalchemy as sa
import ydb
from sqlalchemy import Table, Column, Integer, Unicode
from sqlalchemy.testing.fixtures import TestBase, TablesTest
from sqlalchemy.testing.fixtures import TestBase, TablesTest, config
from ydb._grpc.v4.protos import ydb_common_pb2

from ydb_sqlalchemy import dbapi, IsolationLevel
Expand Down Expand Up @@ -220,8 +220,9 @@ def _create_table_and_get_desc(connection, metadata, **kwargs):
)
table.create(connection)

session: ydb.Session = connection.connection.driver_connection.session
session: ydb.Session = connection.connection.driver_connection.session_pool.acquire()
table_description = session.describe_table("/local/" + table.name)
connection.connection.driver_connection.session_pool.release(session)
return table_description

@pytest.mark.parametrize(
Expand Down Expand Up @@ -419,11 +420,11 @@ def test_interactive_transaction(

connection_no_trans.execution_options(isolation_level=isolation_level)
with connection_no_trans.begin():
tx_id = dbapi_connection.transaction.tx_id
tx_id = dbapi_connection.tx_context.tx_id
assert tx_id is not None
cursor1 = connection_no_trans.execute(sa.select(table))
cursor2 = connection_no_trans.execute(sa.select(table))
assert dbapi_connection.transaction.tx_id == tx_id
assert dbapi_connection.tx_context.tx_id == tx_id

assert set(cursor1.fetchall()) == {(5,), (6,)}
assert set(cursor2.fetchall()) == {(5,), (6,)}
Expand All @@ -448,10 +449,10 @@ def test_not_interactive_transaction(

connection_no_trans.execution_options(isolation_level=isolation_level)
with connection_no_trans.begin():
assert dbapi_connection.transaction is None
assert dbapi_connection.tx_context is None
cursor1 = connection_no_trans.execute(sa.select(table))
cursor2 = connection_no_trans.execute(sa.select(table))
assert dbapi_connection.transaction is None
assert dbapi_connection.tx_context is None

assert set(cursor1.fetchall()) == {(7,), (8,)}
assert set(cursor2.fetchall()) == {(7,), (8,)}
Expand Down Expand Up @@ -482,7 +483,59 @@ def test_connection_set(self, connection_no_trans: sa.Connection):
assert dbapi_connection.tx_mode.name == ydb_isolation_settings[0]
assert dbapi_connection.interactive_transaction is ydb_isolation_settings[1]
if dbapi_connection.interactive_transaction:
assert dbapi_connection.transaction is not None
assert dbapi_connection.transaction.tx_id is not None
assert dbapi_connection.tx_context is not None
assert dbapi_connection.tx_context.tx_id is not None
else:
assert dbapi_connection.transaction is None
assert dbapi_connection.tx_context is None


class TestEngine(TestBase):
    """Check that SQLAlchemy engines can share one externally managed YDB session pool.

    Each test builds two engines whose ``connect_args`` carry the same
    ``ydb.SessionPool`` and verifies that the resulting DBAPI connections
    reuse that pool and its driver, and that disposing the engines does not
    stop the shared driver.
    """

    @pytest.fixture(scope="module")
    def ydb_driver(self):
        """Yield a ``ydb.Driver`` built from the test config URL and stop it on teardown."""
        url = config.db_url
        driver = ydb.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database)
        try:
            driver.wait(timeout=5, fail_fast=True)
            yield driver
        finally:
            # Single teardown point: runs both when wait() fails and on normal
            # fixture finalization, so no second stop() call is needed.
            driver.stop()

    @pytest.fixture(scope="module")
    def ydb_pool(self, ydb_driver):
        """Yield a small ``ydb.SessionPool`` on top of ``ydb_driver`` and stop it on teardown."""
        session_pool = ydb.SessionPool(ydb_driver, size=5, workers_threads_count=1)
        try:
            yield session_pool
        finally:
            session_pool.stop()

    @staticmethod
    def _make_engine(poolclass, ydb_pool):
        """Create an engine for the configured URL that is handed the shared YDB session pool."""
        # A fresh connect_args dict per engine, as in the original inline calls.
        return sa.create_engine(
            config.db_url,
            poolclass=poolclass,
            connect_args={"ydb_session_pool": ydb_pool},
        )

    def _check_shared_session_pool(self, poolclass, ydb_driver, ydb_pool):
        """Assert that two engines using ``poolclass`` share the YDB pool and driver.

        Also asserts that disposing both engines leaves the shared driver running.
        """
        engine1 = self._make_engine(poolclass, ydb_pool)
        engine2 = self._make_engine(poolclass, ydb_pool)
        try:
            with engine1.connect() as conn1, engine2.connect() as conn2:
                dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
                dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection

                assert dbapi_conn1.session_pool is dbapi_conn2.session_pool
                assert dbapi_conn1.driver is dbapi_conn2.driver
        finally:
            # Dispose even when an assertion above fails, so the engines do not
            # outlive the test.
            engine1.dispose()
            engine2.dispose()

        # The driver belongs to the fixture, not to the engines.
        assert not ydb_driver._stopped

    def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
        """Shared YDB session pool with SQLAlchemy's ``QueuePool``."""
        self._check_shared_session_pool(sa.QueuePool, ydb_driver, ydb_pool)

    def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
        """Shared YDB session pool with SQLAlchemy's ``NullPool``."""
        self._check_shared_session_pool(sa.NullPool, ydb_driver, ydb_pool)
3 changes: 2 additions & 1 deletion test_dbapi/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest

import ydb_sqlalchemy.dbapi as dbapi


@pytest.fixture(scope="module")
def connection():
conn = dbapi.connect("localhost:2136", database="/local")
conn = dbapi.connect(host="localhost", port="2136", database="/local")
yield conn
conn.close()
81 changes: 44 additions & 37 deletions ydb_sqlalchemy/dbapi/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import posixpath
from typing import Optional, NamedTuple
from typing import Optional, NamedTuple, Any

import ydb

Expand All @@ -17,23 +17,38 @@ class IsolationLevel:


class Connection:
def __init__(self, endpoint=None, host=None, port=None, database=None, **conn_kwargs):
self.endpoint = endpoint or f"grpc://{host}:{port}"
def __init__(
self,
host: str = "",
port: str = "",
database: str = "",
**conn_kwargs: Any,
):
self.endpoint = f"grpc://{host}:{port}"
self.database = database
self.table_client_settings = self._get_table_client_settings()
self.driver = self._create_driver(**conn_kwargs)
self.session = self._create_session()
self.conn_kwargs = conn_kwargs

if "ydb_session_pool" in self.conn_kwargs: # Use session pool managed manually
self._shared_session_pool = True
self.session_pool: ydb.SessionPool = self.conn_kwargs.pop("ydb_session_pool")
self.driver = self.session_pool._pool_impl._driver
self.driver.table_client = ydb.TableClient(self.driver, self._get_table_client_settings())
else:
self._shared_session_pool = False
self.driver = self._create_driver()
self.session_pool = ydb.SessionPool(self.driver, size=5, workers_threads_count=1)

self.interactive_transaction: bool = False # AUTOCOMMIT
self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite()
self.transaction: Optional[ydb.TxContext] = None
self.tx_context: Optional[ydb.TxContext] = None

def cursor(self):
return Cursor(self, transaction=self.transaction)
return Cursor(self.session_pool, self.tx_context)

def describe(self, table_path):
full_path = posixpath.join(self.database, table_path)
try:
return ydb.retry_operation_sync(lambda: self.session.describe_table(full_path))
return self.session_pool.retry_operation_sync(lambda session: session.describe_table(full_path))
except ydb.issues.SchemeError as e:
raise ProgrammingError(e.message, e.issues, e.status) from e
except ydb.Error as e:
Expand Down Expand Up @@ -64,7 +79,7 @@ class IsolationSettings(NamedTuple):
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly(), interactive=True),
}
ydb_isolation_settings = ydb_isolation_settings_map[isolation_level]
if self.transaction and self.transaction.tx_id:
if self.tx_context and self.tx_context.tx_id:
raise InternalError("Failed to set transaction mode: transaction is already began")
self.tx_mode = ydb_isolation_settings.ydb_mode
self.interactive_transaction = ydb_isolation_settings.interactive
Expand All @@ -88,27 +103,31 @@ def get_isolation_level(self) -> str:
raise NotSupportedError(f"{self.tx_mode.name} is not supported")

def begin(self):
if not self.session.initialized():
raise InternalError("Failed to begin transaction: session closed")
self.transaction = None
self.tx_context = None
if self.interactive_transaction:
self.transaction = self.session.transaction(self.tx_mode)
self.transaction.begin()
session = self.session_pool.acquire(blocking=True)
self.tx_context = session.transaction(self.tx_mode)
self.tx_context.begin()

def commit(self):
if self.transaction and self.transaction.tx_id:
self.transaction.commit()
if self.tx_context and self.tx_context.tx_id:
self.tx_context.commit()
self.session_pool.release(self.tx_context.session)
self.tx_context = None

def rollback(self):
if self.transaction and self.transaction.tx_id:
self.transaction.rollback()
if self.tx_context and self.tx_context.tx_id:
self.tx_context.rollback()
self.session_pool.release(self.tx_context.session)
self.tx_context = None

def close(self):
self._delete_session()
self._stop_driver()
self.rollback()
if not self._shared_session_pool:
self.session_pool.stop()
self._stop_driver()

@staticmethod
def _get_table_client_settings() -> ydb.TableClientSettings:
def _get_table_client_settings(self) -> ydb.TableClientSettings:
return (
ydb.TableClientSettings()
.with_native_date_in_result_sets(True)
Expand All @@ -118,13 +137,11 @@ def _get_table_client_settings() -> ydb.TableClientSettings:
.with_native_json_in_result_sets(True)
)

def _create_driver(self, **conn_kwargs):
# TODO: add cache for initialized drivers/pools?
def _create_driver(self):
driver_config = ydb.DriverConfig(
endpoint=self.endpoint,
database=self.database,
table_client_settings=self.table_client_settings,
**conn_kwargs,
table_client_settings=self._get_table_client_settings(),
)
driver = ydb.Driver(driver_config)
try:
Expand All @@ -138,13 +155,3 @@ def _create_driver(self, **conn_kwargs):

def _stop_driver(self):
self.driver.stop()

def _create_session(self) -> ydb.BaseSession:
session = ydb.Session(self.driver, self.table_client_settings)
session.create()
return session

def _delete_session(self):
if self.session.initialized():
self.rollback()
self.session.delete()
68 changes: 43 additions & 25 deletions ydb_sqlalchemy/dbapi/cursor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import dataclasses
import itertools
import logging

from typing import Any, Mapping, Optional, Sequence, Union, Dict
from typing import Any, Mapping, Optional, Sequence, Union, Dict, Callable

import ydb

from .errors import (
InternalError,
IntegrityError,
Expand All @@ -15,7 +15,6 @@
NotSupportedError,
)


logger = logging.getLogger(__name__)


Expand All @@ -33,10 +32,13 @@ class YdbQuery:


class Cursor(object):
def __init__(self, connection, transaction: Optional[ydb.BaseTxContext] = None):
self.connection = connection
self.session: ydb.Session = self.connection.session
self.transaction = transaction
def __init__(
self,
session_pool: ydb.SessionPool,
tx_context: Optional[ydb.BaseTxContext] = None,
):
self.session_pool = session_pool
self.tx_context = tx_context
self.description = None
self.arraysize = 1
self.rows = None
Expand All @@ -50,7 +52,15 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] =
query = ydb.DataQuery(operation.yql_text, operation.parameters_types)
is_ddl = operation.is_ddl

chunks = self._execute(query, parameters, is_ddl)
logger.info("execute sql: %s, params: %s", query, parameters)
if is_ddl:
chunks = self.session_pool.retry_operation_sync(self._execute_ddl, None, query)
else:
if self.tx_context:
chunks = self._execute_dml(self.tx_context.session, query, parameters, self.tx_context)
else:
chunks = self.session_pool.retry_operation_sync(self._execute_dml, None, query, parameters)

rows = self._rows_iterable(chunks)
# Prefetch the description:
try:
Expand All @@ -64,23 +74,31 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] =

self.rows = rows

def _execute(self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]], is_ddl: bool):
self.description = None
logger.info("execute sql: %s, params: %s", query, parameters)
@classmethod
def _execute_dml(
    cls,
    session: ydb.Session,
    query: ydb.DataQuery,
    parameters: Optional[Mapping[str, Any]] = None,
    tx_context: Optional[ydb.BaseTxContext] = None,
) -> ydb.convert.ResultSets:
    """Execute a data query, inside ``tx_context`` if given, else in a one-shot transaction.

    :param session: session used to prepare the query and, when no
        ``tx_context`` is supplied, to open the one-shot transaction.
    :param query: the query to run; a plain string is prepared on ``session``
        first, but only when ``parameters`` is non-empty.
    :param parameters: optional query parameters.
    :param tx_context: an already opened transaction context; when falsy, a
        new transaction is created on ``session`` and executed with
        ``commit_tx=True``.
    :return: whatever the underlying ``execute`` call returns.

    The ``execute`` call goes through ``cls._handle_ydb_errors``; the
    ``session.prepare`` call does not.
    """
    needs_prepare = isinstance(query, str) and parameters
    prepared_query = session.prepare(query) if needs_prepare else query

    if not tx_context:
        # No surrounding transaction: run in a fresh one and commit with the same call.
        one_shot_tx = session.transaction()
        return cls._handle_ydb_errors(one_shot_tx.execute, prepared_query, parameters, commit_tx=True)

    return cls._handle_ydb_errors(tx_context.execute, prepared_query, parameters)

@classmethod
def _execute_ddl(cls, session: ydb.Session, query: str) -> ydb.convert.ResultSets:
    """Run a scheme (DDL) statement on ``session`` via ``cls._handle_ydb_errors``.

    :param session: session whose ``execute_scheme`` is invoked.
    :param query: the statement passed to ``execute_scheme``.
    :return: whatever ``session.execute_scheme`` returns.
    """
    run_scheme = session.execute_scheme
    return cls._handle_ydb_errors(run_scheme, query)

@staticmethod
def _handle_ydb_errors(callee: Callable, *args, **kwargs) -> Any:
try:
if is_ddl:
return ydb.retry_operation_sync(lambda: self.session.execute_scheme(query))

prepared_query = query
if isinstance(query, str) and parameters:
prepared_query = self.session.prepare(query)

if not self.transaction:
return ydb.retry_operation_sync(
lambda: self.session.transaction().execute(prepared_query, parameters, commit_tx=True)
)
else:
return self.transaction.execute(prepared_query, parameters)
return callee(*args, **kwargs)
except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e:
raise IntegrityError(e.message, e.issues, e.status) from e
except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e:
Expand Down Expand Up @@ -108,7 +126,7 @@ def _execute(self, query: Union[ydb.DataQuery, str], parameters: Optional[Mappin
except ydb.Error as e:
raise DatabaseError(e.message, e.issues, e.status) from e

def _rows_iterable(self, chunks_iterable):
def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets):
try:
for chunk in chunks_iterable:
self.description = [
Expand Down

0 comments on commit a85232c

Please sign in to comment.