Skip to content

Commit

Permalink
fix: support ExpressionField in sorting operations; fixes #55
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Kryukov authored and michaelkryukov committed Jul 30, 2024
1 parent 98873f3 commit 09bb06c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 9 deletions.
50 changes: 44 additions & 6 deletions mongomock_motor/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def _normalize_strings(obj):
if isinstance(obj, list):
return [_normalize_strings(v) for v in obj]

if isinstance(obj, tuple):
return tuple(_normalize_strings(v) for v in obj)

if isinstance(obj, dict):
return {_normalize_strings(k): _normalize_strings(v) for k, v in obj.items()}

Expand All @@ -85,7 +88,41 @@ def _normalize_strings(obj):
return obj


def _patch_iter_documents(collection):
def _patch_iter_documents_and_get_dataset(collection):
"""
When using beanie or other solutions that utilize classes inheriting from
the "str" type, we need to explicitly transform these instances to plain
strings in cases where internal workings of "mongomock" unable to handle
custom string-like classes. Currently only beanie's "ExpressionField" is
transformed to plain strings.
"""

def _iter_documents_with_normalized_strings(fn):
@wraps(fn)
def wrapper(filter):
return fn(_normalize_strings(filter))

return wrapper

collection._iter_documents = _iter_documents_with_normalized_strings(
collection._iter_documents,
)

def _get_dataset_with_normalized_strings(fn):
@wraps(fn)
def wrapper(spec, sort, fields, as_class):
return fn(spec, _normalize_strings(sort), fields, as_class)

return wrapper

collection._get_dataset = _get_dataset_with_normalized_strings(
collection._get_dataset,
)

return collection


def _patch_get_dataset(collection):
"""
When using beanie, keys can have "ExpressionField" type,
that is inherited from "str". Looks like pymongo works ok
Expand All @@ -94,13 +131,14 @@ def _patch_iter_documents(collection):

def with_normalized_strings_in_filter(fn):
@wraps(fn)
def wrapper(filter):
return fn(_normalize_strings(filter))
def wrapper(spec, sort, fields, as_class):
print(sort)
return fn(spec, _normalize_strings(sort), fields, as_class)

return wrapper

collection._iter_documents = with_normalized_strings_in_filter(
collection._iter_documents,
collection._get_dataset = with_normalized_strings_in_filter(
collection._get_dataset,
)

return collection
Expand All @@ -110,7 +148,7 @@ def _patch_collection_internals(collection):
if getattr(collection, '_patched_by_mongomock_motor', False):
return collection
collection = _patch_insert_and_ensure_uniques(collection)
collection = _patch_iter_documents(collection)
collection = _patch_iter_documents_and_get_dataset(collection)
collection._patched_by_mongomock_motor = True
return collection

Expand Down
11 changes: 11 additions & 0 deletions tests/test_beanie.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,14 @@ async def test_beanie_links():
house = houses[0]
await house.fetch_all_links()
assert house.door.height == 2.1


@pytest.mark.anyio
async def test_beanie_sort():
client = AsyncMongoMockClient()
await init_beanie(database=client.beanie_test, document_models=[Door])

await Door.insert_many([Door(width=width) for width in [4, 2, 3, 1]])

doors = await Door.find().sort(Door.width).to_list()
assert [door.width for door in doors] == [1, 2, 3, 4]
6 changes: 3 additions & 3 deletions tests/test_mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pymongo.results import UpdateResult

from mongomock_motor import AsyncMongoMockClient
from mongomock_motor.patches import _patch_iter_documents
from mongomock_motor.patches import _patch_iter_documents_and_get_dataset


@pytest.mark.anyio
Expand Down Expand Up @@ -56,8 +56,8 @@ async def test_no_multiple_patching():
database = AsyncMongoMockClient()['test']

with patch(
'mongomock_motor.patches._patch_iter_documents',
wraps=_patch_iter_documents,
'mongomock_motor.patches._patch_iter_documents_and_get_dataset',
wraps=_patch_iter_documents_and_get_dataset,
) as patch_iter_documents:
for _ in range(2):
collection = database['test']
Expand Down

0 comments on commit 09bb06c

Please sign in to comment.