Skip to content

Commit

Permalink
add upsert support
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan Ilchuk committed Jan 9, 2024
1 parent d413ef3 commit 7df2cb1
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
30 changes: 30 additions & 0 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from datetime import date, datetime

from ydb_sqlalchemy import sqlalchemy as ydb_sa


def clear_sql(stm):
return stm.replace("\n", " ").replace(" ", " ").strip()
Expand Down Expand Up @@ -200,3 +202,31 @@ def test_select_types(self, connection):

row = connection.execute(sa.select(tb)).fetchone()
assert row == (1, "Hello World!", 3.5, True, now, today)


class TestUpsert(TablesTest):
@classmethod
def define_tables(cls, metadata):
Table(
"test_upsert",
metadata,
Column("id", Integer, primary_key=True),
Column("val", Integer),
)

def test_1(self, connection):
tb = self.tables.test_upsert

stm = ydb_sa.upsert(tb).values(id=5, val=5)

# TODO: allow to get string
# assert "UPSERT INTO" in str(stm)

connection.execute(stm)
row = connection.execute(sa.select(tb)).fetchone()
assert row == (5, 5)

stm = ydb_sa.upsert(tb).values(id=5, val=6)
connection.execute(stm)
row = connection.execute(sa.select(tb)).fetchone()
assert row == (5, 6)
4 changes: 3 additions & 1 deletion ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ydb
import ydb_sqlalchemy.dbapi as dbapi
from ydb_sqlalchemy.dbapi.constants import YDB_KEYWORDS
from ydb_sqlalchemy.sqlalchemy.dml import Upsert

import sqlalchemy as sa
from sqlalchemy.exc import CompileError, NoSuchTableError
Expand Down Expand Up @@ -347,7 +348,8 @@ class YqlDDLCompiler(DDLCompiler):


def upsert(table):
return sa.sql.Insert(table)
# return sa.sql.Insert(table)
return Upsert(table)


COLUMN_TYPES = {
Expand Down
23 changes: 23 additions & 0 deletions ydb_sqlalchemy/sqlalchemy/dml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Any
from typing import Optional
from typing import Union

import sqlalchemy as sa


class Upsert(sa.sql.Insert):
__visit_name__ = "upsert"
_propagate_attrs = {"compile_state_plugin": "yql"}

def compile(
self,
bind=None,
dialect=None,
**kw: Any,
):
return super(Upsert, self).compile(bind, **kw)


@sa.sql.base.CompileState.plugin_for("yql", "upsert")
class InsertDMLState(sa.sql.dml.InsertDMLState):
pass

0 comments on commit 7df2cb1

Please sign in to comment.