From 09bb06cec1fc5e1a5f0cc7a0b99b373d64753cf5 Mon Sep 17 00:00:00 2001 From: Michael Kryukov Date: Sun, 28 Jul 2024 15:52:56 +0300 Subject: [PATCH] fix: support ExpressionField in sorting operations; fixes #55 --- mongomock_motor/patches.py | 50 +++++++++++++++++++++++++++++++++----- tests/test_beanie.py | 11 +++++++++ tests/test_mocking.py | 6 ++--- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/mongomock_motor/patches.py b/mongomock_motor/patches.py index 0a1d41b..9254ae1 100644 --- a/mongomock_motor/patches.py +++ b/mongomock_motor/patches.py @@ -75,6 +75,9 @@ def _normalize_strings(obj): if isinstance(obj, list): return [_normalize_strings(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_normalize_strings(v) for v in obj) + if isinstance(obj, dict): return {_normalize_strings(k): _normalize_strings(v) for k, v in obj.items()} @@ -85,7 +88,41 @@ def _normalize_strings(obj): return obj -def _patch_iter_documents(collection): +def _patch_iter_documents_and_get_dataset(collection): + """ + When using beanie or other solutions that utilize classes inheriting from + the "str" type, we need to explicitly transform these instances to plain + strings in cases where internal workings of "mongomock" unable to handle + custom string-like classes. Currently only beanie's "ExpressionField" is + transformed to plain strings. + """ + + def _iter_documents_with_normalized_strings(fn): + @wraps(fn) + def wrapper(filter): + return fn(_normalize_strings(filter)) + + return wrapper + + collection._iter_documents = _iter_documents_with_normalized_strings( + collection._iter_documents, + ) + + def _get_dataset_with_normalized_strings(fn): + @wraps(fn) + def wrapper(spec, sort, fields, as_class): + return fn(spec, _normalize_strings(sort), fields, as_class) + + return wrapper + + collection._get_dataset = _get_dataset_with_normalized_strings( + collection._get_dataset, + ) + + return collection + + +def _patch_get_dataset(collection): """ When using beanie, keys can have "ExpressionField" type, that is inherited from "str". Looks like pymongo works ok @@ -94,13 +131,14 @@ def _patch_iter_documents(collection): def with_normalized_strings_in_filter(fn): @wraps(fn) - def wrapper(filter): - return fn(_normalize_strings(filter)) + def wrapper(spec, sort, fields, as_class): + print(sort) + return fn(spec, _normalize_strings(sort), fields, as_class) return wrapper - collection._iter_documents = with_normalized_strings_in_filter( - collection._iter_documents, + collection._get_dataset = with_normalized_strings_in_filter( + collection._get_dataset, ) return collection @@ -110,7 +148,7 @@ def _patch_collection_internals(collection): if getattr(collection, '_patched_by_mongomock_motor', False): return collection collection = _patch_insert_and_ensure_uniques(collection) - collection = _patch_iter_documents(collection) + collection = _patch_iter_documents_and_get_dataset(collection) collection._patched_by_mongomock_motor = True return collection diff --git a/tests/test_beanie.py b/tests/test_beanie.py index deedf08..03a26d9 100644 --- a/tests/test_beanie.py +++ b/tests/test_beanie.py @@ -75,3 +75,14 @@ async def test_beanie_links(): house = houses[0] await house.fetch_all_links() assert house.door.height == 2.1 + + +@pytest.mark.anyio +async def test_beanie_sort(): + client = AsyncMongoMockClient() + await init_beanie(database=client.beanie_test, document_models=[Door]) + + await Door.insert_many([Door(width=width) for width in [4, 2, 3, 1]]) + + doors = await Door.find().sort(Door.width).to_list() + assert [door.width for door in doors] == [1, 2, 3, 4] diff --git a/tests/test_mocking.py b/tests/test_mocking.py index 102ee80..d043f43 100644 --- a/tests/test_mocking.py +++ b/tests/test_mocking.py @@ -6,7 +6,7 @@ from pymongo.results import UpdateResult from mongomock_motor import AsyncMongoMockClient -from mongomock_motor.patches import _patch_iter_documents +from mongomock_motor.patches import _patch_iter_documents_and_get_dataset @pytest.mark.anyio @@ -56,8 +56,8 @@ async def test_no_multiple_patching(): database = AsyncMongoMockClient()['test'] with patch( - 'mongomock_motor.patches._patch_iter_documents', - wraps=_patch_iter_documents, + 'mongomock_motor.patches._patch_iter_documents_and_get_dataset', + wraps=_patch_iter_documents_and_get_dataset, ) as patch_iter_documents: for _ in range(2): collection = database['test']