diff --git a/tests/fixtures/db_connection.py b/tests/fixtures/db_connection.py index f0057467..70eb353e 100644 --- a/tests/fixtures/db_connection.py +++ b/tests/fixtures/db_connection.py @@ -51,5 +51,6 @@ async def async_session_plain(async_engine): @async_fixture(scope="class") async def async_session(async_session_plain): - async with async_session_plain() as session: + async with async_session_plain() as session: # type: AsyncSession yield session + await session.rollback() diff --git a/tests/models.py b/tests/models.py index 51b73b05..5eaea38e 100644 --- a/tests/models.py +++ b/tests/models.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import declared_attr, relationship from sqlalchemy.types import CHAR, TypeDecorator -from tests.common import sqla_uri +from tests.common import is_postgres_tests, sqla_uri class Base: @@ -271,9 +271,9 @@ def python_type(self): db_uri = sqla_uri() -if "postgres" in db_uri: +if is_postgres_tests(): # noinspection PyPep8Naming - from sqlalchemy.dialects.postgresql import UUID as UUIDType + from sqlalchemy.dialects.postgresql.asyncpg import AsyncpgUUID as UUIDType elif "sqlite" in db_uri: UUIDType = CustomUUIDType else: @@ -283,10 +283,10 @@ def python_type(self): class CustomUUIDItem(Base): __tablename__ = "custom_uuid_item" - id = Column(UUIDType, primary_key=True) + id = Column(UUIDType(as_uuid=True), primary_key=True) extra_id = Column( - UUIDType, + UUIDType(as_uuid=True), nullable=True, unique=True, ) diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index af609bd9..0ad62e83 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -2380,6 +2380,36 @@ async def test_join_by_relationships_does_not_duplicating_response_entities( "meta": {"count": 1, "totalPages": 1}, } + async def test_sqla_filters_by_uuid_type( + self, + async_session: AsyncSession, + ): + """ + This test checks if UUID fields allow filtering by UUID object + + Make sure your UUID field allows native UUID filtering: `UUID(as_uuid=True)` + + :param async_session: + :return: + """ + new_id = uuid4() + extra_id = uuid4() + item = CustomUUIDItem( + id=new_id, + extra_id=extra_id, + ) + async_session.add(item) + await async_session.commit() + + # noinspection PyTypeChecker + stmt = select(CustomUUIDItem) + # works because we set `as_uuid=True` + i = await async_session.scalar(stmt.where(CustomUUIDItem.id == new_id)) + assert i + # works because we set `as_uuid=True` + i = await async_session.scalar(stmt.where(CustomUUIDItem.extra_id == extra_id)) + assert i + @pytest.mark.parametrize("filter_kind", ["small", "full"]) async def test_filter_by_field_of_uuid_type( self,