Skip to content

Commit

Permalink
Merge pull request #2788 from bagerard/fix_no_dereference_swallowing_…
Browse files Browse the repository at this point in the history
…errors

Various fix for no_dereference context manager
  • Loading branch information
bagerard authored Dec 23, 2023
2 parents bfc42d0 + a2bb72b commit cfb4265
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 107 deletions.
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ Development
- Fix validate() not being called when inheritance is used in EmbeddedDocument and validate is overriden #2784
- Add support for readPreferenceTags in connection parameters #2644
- Use estimated_documents_count OR documents_count when count is called, based on the query #2529
- Fix no_dereference context manager which wasn't turning off auto-dereferencing correctly in some cases #2788
- BREAKING CHANGE: no_dereference context manager no longer returns the class in __enter__ #2788
as it was useless and making it look like it was returning a different class although it was the same.
Thus, it must be called like `with no_dereference(User):` and no longer `with no_dereference(User) as ...:`

Changes in 0.27.0
=================
Expand Down
2 changes: 1 addition & 1 deletion docs/guide/querying.rst
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ data. To turn off dereferencing of the results of a query use
You can also turn off all dereferencing for a fixed period by using the
:class:`~mongoengine.context_managers.no_dereference` context manager::

with no_dereference(Post) as Post:
with no_dereference(Post):
post = Post.objects.first()
assert(isinstance(post.author, DBRef))

Expand Down
28 changes: 25 additions & 3 deletions mongoengine/context_managers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
from contextlib import contextmanager

from pymongo.read_concern import ReadConcern
Expand All @@ -18,6 +19,25 @@
)


thread_locals = threading.local()
thread_locals.no_dereferencing_class = {}


def no_dereferencing_active_for_class(cls):
return cls in thread_locals.no_dereferencing_class


def _register_no_dereferencing_for_class(cls):
thread_locals.no_dereferencing_class.setdefault(cls, 0)
thread_locals.no_dereferencing_class[cls] += 1


def _unregister_no_dereferencing_for_class(cls):
thread_locals.no_dereferencing_class[cls] -= 1
if thread_locals.no_dereferencing_class[cls] == 0:
thread_locals.no_dereferencing_class.pop(cls)


class switch_db:
"""switch_db alias context manager.
Expand Down Expand Up @@ -107,7 +127,7 @@ class no_dereference:
Turns off all dereferencing in Documents for the duration of the context
manager::
with no_dereference(Group) as Group:
with no_dereference(Group):
Group.objects.find()
"""

Expand All @@ -130,15 +150,17 @@ def __init__(self, cls):

def __enter__(self):
"""Change the objects default and _auto_dereference values."""
_register_no_dereferencing_for_class(self.cls)

for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = False
return self.cls

def __exit__(self, t, value, traceback):
"""Reset the default and _auto_dereference values."""
_unregister_no_dereferencing_for_class(self.cls)

for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = True
return self.cls


class no_sub_classes:
Expand Down
19 changes: 13 additions & 6 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mongoengine.common import _import_class
from mongoengine.connection import get_db
from mongoengine.context_managers import (
no_dereferencing_active_for_class,
set_read_write_concern,
set_write_concern,
switch_db,
Expand Down Expand Up @@ -51,9 +52,6 @@ class BaseQuerySet:
providing :class:`~mongoengine.Document` objects as the results.
"""

__dereference = False
_auto_dereference = True

def __init__(self, document, collection):
self._document = document
self._collection_obj = collection
Expand All @@ -74,6 +72,9 @@ def __init__(self, document, collection):
self._as_pymongo = False
self._search_text = None

self.__dereference = False
self.__auto_dereference = True

# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
if document._meta.get("allow_inheritance") is True:
Expand Down Expand Up @@ -795,7 +796,7 @@ def clone(self):
return self._clone_into(self.__class__(self._document, self._collection_obj))

def _clone_into(self, new_qs):
"""Copy all of the relevant properties of this queryset to
"""Copy all the relevant properties of this queryset to
a new queryset (which has to be an instance of
:class:`~mongoengine.queryset.base.BaseQuerySet`).
"""
Expand Down Expand Up @@ -825,7 +826,6 @@ def _clone_into(self, new_qs):
"_empty",
"_hint",
"_collation",
"_auto_dereference",
"_search_text",
"_max_time_ms",
"_comment",
Expand All @@ -836,6 +836,8 @@ def _clone_into(self, new_qs):
val = getattr(self, prop)
setattr(new_qs, prop, copy.copy(val))

new_qs.__auto_dereference = self._BaseQuerySet__auto_dereference

if self._cursor_obj:
new_qs._cursor_obj = self._cursor_obj.clone()

Expand Down Expand Up @@ -1741,10 +1743,15 @@ def _dereference(self):
self.__dereference = _import_class("DeReference")()
return self.__dereference

@property
def _auto_dereference(self):
should_deref = not no_dereferencing_active_for_class(self._document)
return should_deref and self.__auto_dereference

def no_dereference(self):
"""Turn off any dereferencing for the results of this queryset."""
queryset = self.clone()
queryset._auto_dereference = False
queryset.__auto_dereference = False
return queryset

# Helper Functions
Expand Down
83 changes: 11 additions & 72 deletions tests/document/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from mongoengine.connection import get_db
from mongoengine.mongodb_support import (
MONGODB_42,
MONGODB_70,
get_mongodb_version,
)
from mongoengine.pymongo_support import PYMONGO_VERSION
Expand Down Expand Up @@ -451,89 +450,29 @@ class Test(Document):
# the documents returned might have more keys in that here.
query_plan = Test.objects(id=obj.id).exclude("a").explain()
assert (
query_plan.get("queryPlanner")
.get("winningPlan")
.get("inputStage")
.get("stage")
== "IDHACK"
query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IDHACK"
)

query_plan = Test.objects(id=obj.id).only("id").explain()
assert (
query_plan.get("queryPlanner")
.get("winningPlan")
.get("inputStage")
.get("stage")
== "IDHACK"
query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IDHACK"
)

mongo_db = get_mongodb_version()
query_plan = Test.objects(a=1).only("a").exclude("id").explain()
if mongo_db < MONGODB_70:
assert (
query_plan.get("queryPlanner")
.get("winningPlan")
.get("inputStage")
.get("stage")
== "IXSCAN"
)
else:
assert (
query_plan.get("queryPlanner")
.get("winningPlan")
.get("queryPlan")
.get("inputStage")
.get("stage")
== "IXSCAN"
)
assert (
query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IXSCAN"
)

PROJECTION_STR = "PROJECTION" if mongo_db < MONGODB_42 else "PROJECTION_COVERED"
if mongo_db < MONGODB_70:
assert (
query_plan.get("queryPlanner").get("winningPlan").get("stage")
== PROJECTION_STR
)
else:
assert (
query_plan.get("queryPlanner")
.get("winningPlan")
.get("queryPlan")
.get("stage")
== PROJECTION_STR
)
assert query_plan["queryPlanner"]["winningPlan"]["stage"] == PROJECTION_STR

query_plan = Test.objects(a=1).explain()
if mongo_db < MONGODB_70:
assert (
query_plan.get("queryPlanner")
.get("winningPlan")
.get("inputStage")
.get("stage")
== "IXSCAN"
)
else:
assert (
query_plan.get("queryPlanner")
.get("winningPlan")
.get("queryPlan")
.get("inputStage")
.get("stage")
== "IXSCAN"
)

if mongo_db < MONGODB_70:
assert (
query_plan.get("queryPlanner").get("winningPlan").get("stage")
== "FETCH"
)
else:
assert (
query_plan.get("queryPlanner")
.get("winningPlan")
.get("queryPlan")
.get("stage")
== "FETCH"
)
assert (
query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IXSCAN"
)

assert query_plan.get("queryPlanner").get("winningPlan").get("stage") == "FETCH"

def test_index_on_id(self):
class BlogPost(Document):
Expand Down
Loading

0 comments on commit cfb4265

Please sign in to comment.