From 7df2cb1cf0d71bd3fab29d192fd0b03292f3864a Mon Sep 17 00:00:00 2001 From: Ivan Ilchuk Date: Tue, 9 Jan 2024 13:36:29 +0100 Subject: [PATCH] add upsert support --- test/test_core.py | 30 +++++++++++++++++++++++++++ ydb_sqlalchemy/sqlalchemy/__init__.py | 4 +++- ydb_sqlalchemy/sqlalchemy/dml.py | 23 ++++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 ydb_sqlalchemy/sqlalchemy/dml.py diff --git a/test/test_core.py b/test/test_core.py index 52c61e1..fd7bb51 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -7,6 +7,8 @@ from datetime import date, datetime +from ydb_sqlalchemy import sqlalchemy as ydb_sa + def clear_sql(stm): return stm.replace("\n", " ").replace(" ", " ").strip() @@ -200,3 +202,31 @@ def test_select_types(self, connection): row = connection.execute(sa.select(tb)).fetchone() assert row == (1, "Hello World!", 3.5, True, now, today) + + +class TestUpsert(TablesTest): + @classmethod + def define_tables(cls, metadata): + Table( + "test_upsert", + metadata, + Column("id", Integer, primary_key=True), + Column("val", Integer), + ) + + def test_1(self, connection): + tb = self.tables.test_upsert + + stm = ydb_sa.upsert(tb).values(id=5, val=5) + + # TODO: allow to get string + # assert "UPSERT INTO" in str(stm) + + connection.execute(stm) + row = connection.execute(sa.select(tb)).fetchone() + assert row == (5, 5) + + stm = ydb_sa.upsert(tb).values(id=5, val=6) + connection.execute(stm) + row = connection.execute(sa.select(tb)).fetchone() + assert row == (5, 6) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 742e356..61f293d 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -6,6 +6,7 @@ import ydb import ydb_sqlalchemy.dbapi as dbapi from ydb_sqlalchemy.dbapi.constants import YDB_KEYWORDS +from ydb_sqlalchemy.sqlalchemy.dml import Upsert import sqlalchemy as sa from sqlalchemy.exc import CompileError, NoSuchTableError @@ -347,7 +348,8 @@ class YqlDDLCompiler(DDLCompiler): def upsert(table): - return sa.sql.Insert(table) + # return sa.sql.Insert(table) + return Upsert(table) COLUMN_TYPES = { diff --git a/ydb_sqlalchemy/sqlalchemy/dml.py b/ydb_sqlalchemy/sqlalchemy/dml.py new file mode 100644 index 0000000..8672c36 --- /dev/null +++ b/ydb_sqlalchemy/sqlalchemy/dml.py @@ -0,0 +1,23 @@ +from typing import Any +from typing import Optional +from typing import Union + +import sqlalchemy as sa + + +class Upsert(sa.sql.Insert): + __visit_name__ = "upsert" + _propagate_attrs = {"compile_state_plugin": "yql"} + + def compile( + self, + bind=None, + dialect=None, + **kw: Any, + ): + return super(Upsert, self).compile(bind, **kw) + + +@sa.sql.base.CompileState.plugin_for("yql", "upsert") +class InsertDMLState(sa.sql.dml.InsertDMLState): + pass