Skip to content

Commit

Permalink
Merge branch 'main' into add-upsert-support
Browse files Browse the repository at this point in the history
  • Loading branch information
ilchuk96 authored Jan 16, 2024
2 parents a2fb6e3 + dc7381a commit 167c1ac
Show file tree
Hide file tree
Showing 7 changed files with 378 additions and 79 deletions.
183 changes: 178 additions & 5 deletions test/test_core.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from decimal import Decimal
from datetime import date, datetime
from decimal import Decimal
from typing import NamedTuple

import pytest

import sqlalchemy as sa
from sqlalchemy import Table, Column, Integer, Unicode, String
from sqlalchemy.testing.fixtures import TestBase, TablesTest
from sqlalchemy.testing.fixtures import TestBase, TablesTest, config

import ydb
from ydb._grpc.v4.protos import ydb_common_pb2

from ydb_sqlalchemy import dbapi, IsolationLevel
from ydb_sqlalchemy.sqlalchemy import types

from ydb_sqlalchemy import sqlalchemy as ydb_sa


Expand Down Expand Up @@ -221,9 +223,9 @@ def _create_table_and_get_desc(connection, metadata, **kwargs):
)
table.create(connection)

session: ydb.Session = connection.connection.driver_connection.pool.acquire()
session: ydb.Session = connection.connection.driver_connection.session_pool.acquire()
table_description = session.describe_table("/local/" + table.name)
session.delete()
connection.connection.driver_connection.session_pool.release(session)
return table_description

@pytest.mark.parametrize(
Expand Down Expand Up @@ -371,6 +373,177 @@ def test_several_keys(self, connection, metadata):
assert desc.partitioning_settings.max_partitions_count == 5


class TestTransaction(TablesTest):
@classmethod
def define_tables(cls, metadata: sa.MetaData):
Table(
"test",
metadata,
Column("id", Integer, primary_key=True),
)

def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Connection):
table = self.tables.test

connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
with connection_no_trans.begin():
stm1 = table.insert().values(id=1)
connection_no_trans.execute(stm1)
stm2 = table.insert().values(id=2)
connection_no_trans.execute(stm2)
connection_no_trans.rollback()

cursor = connection.execute(sa.select(table))
result = cursor.fetchall()
assert result == []

def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connection):
table = self.tables.test

connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
with connection_no_trans.begin():
stm1 = table.insert().values(id=3)
connection_no_trans.execute(stm1)
stm2 = table.insert().values(id=4)
connection_no_trans.execute(stm2)

cursor = connection.execute(sa.select(table))
result = cursor.fetchall()
assert set(result) == {(3,), (4,)}

@pytest.mark.parametrize("isolation_level", (IsolationLevel.SERIALIZABLE, IsolationLevel.SNAPSHOT_READONLY))
def test_interactive_transaction(
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
):
table = self.tables.test
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

stm1 = table.insert().values([{"id": 5}, {"id": 6}])
connection.execute(stm1)

connection_no_trans.execution_options(isolation_level=isolation_level)
with connection_no_trans.begin():
tx_id = dbapi_connection.tx_context.tx_id
assert tx_id is not None
cursor1 = connection_no_trans.execute(sa.select(table))
cursor2 = connection_no_trans.execute(sa.select(table))
assert dbapi_connection.tx_context.tx_id == tx_id

assert set(cursor1.fetchall()) == {(5,), (6,)}
assert set(cursor2.fetchall()) == {(5,), (6,)}

@pytest.mark.parametrize(
"isolation_level",
(
IsolationLevel.ONLINE_READONLY,
IsolationLevel.ONLINE_READONLY_INCONSISTENT,
IsolationLevel.STALE_READONLY,
IsolationLevel.AUTOCOMMIT,
),
)
def test_not_interactive_transaction(
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
):
table = self.tables.test
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

stm1 = table.insert().values([{"id": 7}, {"id": 8}])
connection.execute(stm1)

connection_no_trans.execution_options(isolation_level=isolation_level)
with connection_no_trans.begin():
assert dbapi_connection.tx_context is None
cursor1 = connection_no_trans.execute(sa.select(table))
cursor2 = connection_no_trans.execute(sa.select(table))
assert dbapi_connection.tx_context is None

assert set(cursor1.fetchall()) == {(7,), (8,)}
assert set(cursor2.fetchall()) == {(7,), (8,)}


class TestTransactionIsolationLevel(TestBase):
class IsolationSettings(NamedTuple):
ydb_mode: ydb.AbstractTransactionModeBuilder
interactive: bool

YDB_ISOLATION_SETTINGS_MAP = {
IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.SerializableReadWrite().name, False),
IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.SerializableReadWrite().name, True),
IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.OnlineReadOnly().name, False),
IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings(
ydb.OnlineReadOnly().with_allow_inconsistent_reads().name, False
),
IsolationLevel.STALE_READONLY: IsolationSettings(ydb.StaleReadOnly().name, False),
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly().name, True),
}

def test_connection_set(self, connection_no_trans: sa.Connection):
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

for sa_isolation_level, ydb_isolation_settings in self.YDB_ISOLATION_SETTINGS_MAP.items():
connection_no_trans.execution_options(isolation_level=sa_isolation_level)
with connection_no_trans.begin():
assert dbapi_connection.tx_mode.name == ydb_isolation_settings[0]
assert dbapi_connection.interactive_transaction is ydb_isolation_settings[1]
if dbapi_connection.interactive_transaction:
assert dbapi_connection.tx_context is not None
assert dbapi_connection.tx_context.tx_id is not None
else:
assert dbapi_connection.tx_context is None


class TestEngine(TestBase):
@pytest.fixture(scope="module")
def ydb_driver(self):
url = config.db_url
driver = ydb.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database)
try:
driver.wait(timeout=5, fail_fast=True)
yield driver
finally:
driver.stop()

driver.stop()

@pytest.fixture(scope="module")
def ydb_pool(self, ydb_driver):
session_pool = ydb.SessionPool(ydb_driver, size=5, workers_threads_count=1)

yield session_pool

session_pool.stop()

def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})

with engine1.connect() as conn1, engine2.connect() as conn2:
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection

assert dbapi_conn1.session_pool is dbapi_conn2.session_pool
assert dbapi_conn1.driver is dbapi_conn2.driver

engine1.dispose()
engine2.dispose()
assert not ydb_driver._stopped

def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
engine1 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})

with engine1.connect() as conn1, engine2.connect() as conn2:
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection

assert dbapi_conn1.session_pool is dbapi_conn2.session_pool
assert dbapi_conn1.driver is dbapi_conn2.driver

engine1.dispose()
engine2.dispose()
assert not ydb_driver._stopped


class TestUpsert(TablesTest):
@classmethod
def define_tables(cls, metadata):
Expand Down
3 changes: 2 additions & 1 deletion test_dbapi/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest

import ydb_sqlalchemy.dbapi as dbapi


@pytest.fixture(scope="module")
def connection():
conn = dbapi.connect("localhost:2136", database="/local")
conn = dbapi.connect(host="localhost", port="2136", database="/local")
yield conn
conn.close()
1 change: 1 addition & 0 deletions ydb_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .dbapi import IsolationLevel # noqa: F401
2 changes: 1 addition & 1 deletion ydb_sqlalchemy/dbapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .connection import Connection
from .connection import Connection, IsolationLevel # noqa: F401
from .cursor import Cursor, YdbQuery # noqa: F401
from .errors import (
Warning,
Expand Down
Loading

0 comments on commit 167c1ac

Please sign in to comment.