Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan Ilchuk committed Jan 10, 2024
1 parent 7df2cb1 commit 7c7a6d3
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
32 changes: 25 additions & 7 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,19 +214,37 @@ def define_tables(cls, metadata):
Column("val", Integer),
)

def test_1(self, connection):
def test_string(self, connection):
tb = self.tables.test_upsert

stm = ydb_sa.upsert(tb).values(id=5, val=5)

# TODO: allow to get string
# assert "UPSERT INTO" in str(stm)
assert "UPSERT INTO" in str(stm)

def test_upsert_new_id(self, connection):
tb = self.tables.test_upsert

stm = ydb_sa.upsert(tb).values(id=1, val=1)

connection.execute(stm)
row = connection.execute(sa.select(tb)).fetchone()
assert row == (5, 5)
row = connection.execute(sa.select(tb)).fetchall()
assert row == [(1, 1)]

stm = ydb_sa.upsert(tb).values(id=2, val=2)
connection.execute(stm)
row = connection.execute(sa.select(tb)).fetchall()
assert row == [(1, 1), (2, 2)]

def test_upsert_existing_id(self, connection):
tb = self.tables.test_upsert

stm = ydb_sa.upsert(tb).values(id=5, val=5)

connection.execute(stm)
row = connection.execute(sa.select(tb)).fetchall()
assert row == [(5, 5)]

stm = ydb_sa.upsert(tb).values(id=5, val=6)
connection.execute(stm)
row = connection.execute(sa.select(tb)).fetchone()
assert row == (5, 6)
row = connection.execute(sa.select(tb)).fetchall()
assert row == [(5, 6)]
4 changes: 3 additions & 1 deletion ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,15 @@ def get_bind_types(

return parameter_types

def visit_upsert(self, insert_stmt, visited_bindparam=None, **kw):
return self.visit_insert(insert_stmt, visited_bindparam, **kw).replace("INSERT", "UPSERT")


class YqlDDLCompiler(DDLCompiler):
pass


def upsert(table):
# return sa.sql.Insert(table)
return Upsert(table)


Expand Down
11 changes: 2 additions & 9 deletions ydb_sqlalchemy/sqlalchemy/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,9 @@
class Upsert(sa.sql.Insert):
__visit_name__ = "upsert"
_propagate_attrs = {"compile_state_plugin": "yql"}

def compile(
self,
bind=None,
dialect=None,
**kw: Any,
):
return super(Upsert, self).compile(bind, **kw)
stringify_dialect = "yql"


@sa.sql.base.CompileState.plugin_for("yql", "upsert")
class InsertDMLState(sa.sql.dml.InsertDMLState):
class UpsertDMLState(sa.sql.dml.InsertDMLState):
pass

0 comments on commit 7c7a6d3

Please sign in to comment.