diff --git a/test/test_core.py b/test/test_core.py index fd7bb51..2d99a18 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -214,19 +214,37 @@ def define_tables(cls, metadata): Column("val", Integer), ) - def test_1(self, connection): + def test_string(self, connection): tb = self.tables.test_upsert stm = ydb_sa.upsert(tb).values(id=5, val=5) - # TODO: allow to get string - # assert "UPSERT INTO" in str(stm) + assert "UPSERT INTO" in str(stm) + + def test_upsert_new_id(self, connection): + tb = self.tables.test_upsert + + stm = ydb_sa.upsert(tb).values(id=1, val=1) connection.execute(stm) - row = connection.execute(sa.select(tb)).fetchone() - assert row == (5, 5) + row = connection.execute(sa.select(tb)).fetchall() + assert row == [(1, 1)] + + stm = ydb_sa.upsert(tb).values(id=2, val=2) + connection.execute(stm) + row = connection.execute(sa.select(tb)).fetchall() + assert row == [(1, 1), (2, 2)] + + def test_upsert_existing_id(self, connection): + tb = self.tables.test_upsert + + stm = ydb_sa.upsert(tb).values(id=5, val=5) + + connection.execute(stm) + row = connection.execute(sa.select(tb)).fetchall() + assert row == [(5, 5)] stm = ydb_sa.upsert(tb).values(id=5, val=6) connection.execute(stm) - row = connection.execute(sa.select(tb)).fetchone() - assert row == (5, 6) + row = connection.execute(sa.select(tb)).fetchall() + assert row == [(5, 6)] diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 61f293d..ddc0dea 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -342,13 +342,15 @@ def get_bind_types( return parameter_types + def visit_upsert(self, insert_stmt, visited_bindparam=None, **kw): + return self.visit_insert(insert_stmt, visited_bindparam, **kw).replace("INSERT", "UPSERT") + class YqlDDLCompiler(DDLCompiler): pass def upsert(table): - # return sa.sql.Insert(table) return Upsert(table) diff --git a/ydb_sqlalchemy/sqlalchemy/dml.py b/ydb_sqlalchemy/sqlalchemy/dml.py index 8672c36..b2eabaf 100644 --- a/ydb_sqlalchemy/sqlalchemy/dml.py +++ b/ydb_sqlalchemy/sqlalchemy/dml.py @@ -8,16 +8,9 @@ class Upsert(sa.sql.Insert): __visit_name__ = "upsert" _propagate_attrs = {"compile_state_plugin": "yql"} - - def compile( - self, - bind=None, - dialect=None, - **kw: Any, - ): - return super(Upsert, self).compile(bind, **kw) + stringify_dialect = "yql" @sa.sql.base.CompileState.plugin_for("yql", "upsert") -class InsertDMLState(sa.sql.dml.InsertDMLState): +class UpsertDMLState(sa.sql.dml.InsertDMLState): pass