diff --git a/setup.py b/setup.py index 87e7555..1cc3fb0 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ "sqlalchemy.dialects": [ "yql.ydb=ydb_sqlalchemy.sqlalchemy:YqlDialect", "ydb=ydb_sqlalchemy.sqlalchemy:YqlDialect", + "yql=ydb_sqlalchemy.sqlalchemy:YqlDialect", ] }, ) diff --git a/test/conftest.py b/test/conftest.py index 615796c..0f8b014 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,6 +3,7 @@ registry.register("yql.ydb", "ydb_sqlalchemy.sqlalchemy", "YqlDialect") registry.register("ydb", "ydb_sqlalchemy.sqlalchemy", "YqlDialect") +registry.register("yql", "ydb_sqlalchemy.sqlalchemy", "YqlDialect") pytest.register_assert_rewrite("sqlalchemy.testing.assertions") from sqlalchemy.testing.plugin.pytestplugin import * # noqa: E402, F401, F403 diff --git a/test/test_core.py b/test/test_core.py index fa4e326..1ff9f15 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -3,14 +3,17 @@ from typing import NamedTuple import pytest + import sqlalchemy as sa -import ydb -from sqlalchemy import Table, Column, Integer, Unicode +from sqlalchemy import Table, Column, Integer, Unicode, String from sqlalchemy.testing.fixtures import TestBase, TablesTest, config + +import ydb from ydb._grpc.v4.protos import ydb_common_pb2 from ydb_sqlalchemy import dbapi, IsolationLevel from ydb_sqlalchemy.sqlalchemy import types +from ydb_sqlalchemy import sqlalchemy as ydb_sa def clear_sql(stm): @@ -539,3 +542,136 @@ def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): engine1.dispose() engine2.dispose() assert not ydb_driver._stopped + + +class TestUpsert(TablesTest): + @classmethod + def define_tables(cls, metadata): + Table( + "test_upsert", + metadata, + Column("id", Integer, primary_key=True), + Column("val", Integer), + ) + + def test_string(self, connection): + tb = self.tables.test_upsert + stm = ydb_sa.upsert(tb).values(id=0, val=5) + + assert str(stm) == "UPSERT INTO test_upsert (id, val) VALUES (?, ?)" + + def test_upsert_new_id(self, connection): + tb = self.tables.test_upsert + stm = ydb_sa.upsert(tb).values(id=0, val=1) + connection.execute(stm) + row = connection.execute(sa.select(tb)).fetchall() + assert row == [(0, 1)] + + stm = ydb_sa.upsert(tb).values(id=1, val=2) + connection.execute(stm) + row = connection.execute(sa.select(tb)).fetchall() + assert row == [(0, 1), (1, 2)] + + def test_upsert_existing_id(self, connection): + tb = self.tables.test_upsert + stm = ydb_sa.upsert(tb).values(id=0, val=5) + connection.execute(stm) + row = connection.execute(sa.select(tb)).fetchall() + + assert row == [(0, 5)] + + stm = ydb_sa.upsert(tb).values(id=0, val=6) + connection.execute(stm) + row = connection.execute(sa.select(tb)).fetchall() + + assert row == [(0, 6)] + + def test_upsert_several_diff_id(self, connection): + tb = self.tables.test_upsert + stm = ydb_sa.upsert(tb).values( + [ + {"id": 0, "val": 4}, + {"id": 1, "val": 5}, + {"id": 2, "val": 6}, + ] + ) + connection.execute(stm) + row = connection.execute(sa.select(tb)).fetchall() + + assert row == [(0, 4), (1, 5), (2, 6)] + + def test_upsert_several_same_id(self, connection): + tb = self.tables.test_upsert + stm = ydb_sa.upsert(tb).values( + [ + {"id": 0, "val": 4}, + {"id": 0, "val": 5}, + {"id": 0, "val": 6}, + ] + ) + connection.execute(stm) + row = connection.execute(sa.select(tb)).fetchall() + + assert row == [(0, 6)] + + def test_upsert_from_select(self, connection, metadata): + table_to_select_from = Table( + "table_to_select_from", + metadata, + Column("id", Integer, primary_key=True), + Column("val", Integer), + ) + table_to_select_from.create(connection) + stm = sa.insert(table_to_select_from).values( + [ + {"id": 100, "val": 0}, + {"id": 110, "val": 1}, + {"id": 120, "val": 2}, + {"id": 130, "val": 3}, + ] + ) + connection.execute(stm) + + tb = self.tables.test_upsert + select_stm = sa.select(table_to_select_from.c.id, table_to_select_from.c.val).where( + table_to_select_from.c.id > 115, + ) + upsert_stm = ydb_sa.upsert(tb).from_select(["id", "val"], select_stm) + connection.execute(upsert_stm) + row = connection.execute(sa.select(tb)).fetchall() + + assert row == [(120, 2), (130, 3)] + + +class TestUpsertDoesNotReplaceInsert(TablesTest): + @classmethod + def define_tables(cls, metadata): + Table( + "test_upsert_does_not_replace_insert", + metadata, + Column("id", Integer, primary_key=True), + Column("VALUE_TO_INSERT", String), + ) + + def test_string(self, connection): + tb = self.tables.test_upsert_does_not_replace_insert + + stm = ydb_sa.upsert(tb).values(id=0, VALUE_TO_INSERT="5") + + assert str(stm) == "UPSERT INTO test_upsert_does_not_replace_insert (id, `VALUE_TO_INSERT`) VALUES (?, ?)" + + def test_insert_in_name(self, connection): + tb = self.tables.test_upsert_does_not_replace_insert + stm = ydb_sa.upsert(tb).values(id=1, VALUE_TO_INSERT="5") + connection.execute(stm) + row = connection.execute(sa.select(tb).where(tb.c.id == 1)).fetchone() + + assert row == (1, "5") + + def test_insert_in_name_and_field(self, connection): + tb = self.tables.test_upsert_does_not_replace_insert + stm = ydb_sa.upsert(tb).values(id=2, VALUE_TO_INSERT="INSERT is my favourite operation") + connection.execute(stm) + row = connection.execute(sa.select(tb).where(tb.c.id == 2)).fetchone() + + assert row == (2, "INSERT is my favourite operation") diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 52d6991..ed21898 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -6,6 +6,7 @@ import ydb import ydb_sqlalchemy.dbapi as dbapi from ydb_sqlalchemy.dbapi.constants import YDB_KEYWORDS +from ydb_sqlalchemy.sqlalchemy.dml import Upsert import sqlalchemy as sa from sqlalchemy.exc import CompileError, NoSuchTableError @@ -341,6 +342,9 @@ def get_bind_types( return parameter_types + def visit_upsert(self, insert_stmt, visited_bindparam=None, **kw): + return self.visit_insert(insert_stmt, visited_bindparam, **kw).replace("INSERT", "UPSERT", 1) + class YqlDDLCompiler(DDLCompiler): def post_create_table(self, table: sa.Table) -> str: @@ -379,7 +383,7 @@ def _render_table_partitioning_settings(self, ydb_opts: Dict[str, Any]) -> List[ def upsert(table): - return sa.sql.Insert(table) + return Upsert(table) COLUMN_TYPES = { diff --git a/ydb_sqlalchemy/sqlalchemy/dml.py b/ydb_sqlalchemy/sqlalchemy/dml.py new file mode 100644 index 0000000..5abbdbb --- /dev/null +++ b/ydb_sqlalchemy/sqlalchemy/dml.py @@ -0,0 +1,12 @@ +import sqlalchemy as sa + + +class Upsert(sa.sql.Insert): + __visit_name__ = "upsert" + _propagate_attrs = {"compile_state_plugin": "yql"} + stringify_dialect = "yql" + + +@sa.sql.base.CompileState.plugin_for("yql", "upsert") +class UpsertDMLState(sa.sql.dml.InsertDMLState): + pass