Skip to content

Commit 9b92cc7

Browse files
committed
Fix no_dereferencing context manager which wasn't turning off auto-dereferencing correctly in some cases + fix tests
1 parent bfc42d0 commit 9b92cc7

File tree

5 files changed

+92
-35
lines changed

5 files changed

+92
-35
lines changed

docs/changelog.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ Development
1212
- Fix validate() not being called when inheritance is used in EmbeddedDocument and validate is overriden #2784
1313
- Add support for readPreferenceTags in connection parameters #2644
1414
- Use estimated_documents_count OR documents_count when count is called, based on the query #2529
15+
- Fix no_dereferencing context manager which wasn't turning off auto-dereferencing correctly in some cases
16+
- BREAKING CHANGE: no_dereferencing context manager no longer returns the class in __enter__
17+
as it was useless and making it look like it was returning a different class
1518

1619
Changes in 0.27.0
1720
=================

docs/guide/querying.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ data. To turn off dereferencing of the results of a query use
522522
You can also turn off all dereferencing for a fixed period by using the
523523
:class:`~mongoengine.context_managers.no_dereference` context manager::
524524

525-
with no_dereference(Post) as Post:
525+
with no_dereference(Post):
526526
post = Post.objects.first()
527527
assert(isinstance(post.author, DBRef))
528528

mongoengine/context_managers.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import threading
12
from contextlib import contextmanager
23

34
from pymongo.read_concern import ReadConcern
@@ -18,6 +19,25 @@
1819
)
1920

2021

22+
thread_locals = threading.local()
23+
thread_locals.no_dereferencing_class = {}
24+
25+
26+
def no_dereferencing_active_for_class(cls):
27+
return cls in thread_locals.no_dereferencing_class
28+
29+
30+
def _register_no_dereferencing_for_class(cls):
31+
thread_locals.no_dereferencing_class.setdefault(cls, 0)
32+
thread_locals.no_dereferencing_class[cls] += 1
33+
34+
35+
def _unregister_no_dereferencing_for_class(cls):
36+
thread_locals.no_dereferencing_class[cls] -= 1
37+
if thread_locals.no_dereferencing_class[cls] == 0:
38+
thread_locals.no_dereferencing_class.pop(cls)
39+
40+
2141
class switch_db:
2242
"""switch_db alias context manager.
2343
@@ -107,7 +127,7 @@ class no_dereference:
107127
Turns off all dereferencing in Documents for the duration of the context
108128
manager::
109129
110-
with no_dereference(Group) as Group:
130+
with no_dereference(Group):
111131
Group.objects.find()
112132
"""
113133

@@ -130,15 +150,17 @@ def __init__(self, cls):
130150

131151
def __enter__(self):
132152
"""Change the objects default and _auto_dereference values."""
153+
_register_no_dereferencing_for_class(self.cls)
154+
133155
for field in self.deref_fields:
134156
self.cls._fields[field]._auto_dereference = False
135-
return self.cls
136157

137158
def __exit__(self, t, value, traceback):
138159
"""Reset the default and _auto_dereference values."""
160+
_unregister_no_dereferencing_for_class(self.cls)
161+
139162
for field in self.deref_fields:
140163
self.cls._fields[field]._auto_dereference = True
141-
return self.cls
142164

143165

144166
class no_sub_classes:

mongoengine/queryset/base.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from mongoengine.common import _import_class
1818
from mongoengine.connection import get_db
1919
from mongoengine.context_managers import (
20+
no_dereferencing_active_for_class,
2021
set_read_write_concern,
2122
set_write_concern,
2223
switch_db,
@@ -51,9 +52,6 @@ class BaseQuerySet:
5152
providing :class:`~mongoengine.Document` objects as the results.
5253
"""
5354

54-
__dereference = False
55-
_auto_dereference = True
56-
5755
def __init__(self, document, collection):
5856
self._document = document
5957
self._collection_obj = collection
@@ -74,6 +72,9 @@ def __init__(self, document, collection):
7472
self._as_pymongo = False
7573
self._search_text = None
7674

75+
self.__dereference = False
76+
self.__auto_dereference = True
77+
7778
# If inheritance is allowed, only return instances and instances of
7879
# subclasses of the class being used
7980
if document._meta.get("allow_inheritance") is True:
@@ -795,7 +796,7 @@ def clone(self):
795796
return self._clone_into(self.__class__(self._document, self._collection_obj))
796797

797798
def _clone_into(self, new_qs):
798-
"""Copy all of the relevant properties of this queryset to
799+
"""Copy all the relevant properties of this queryset to
799800
a new queryset (which has to be an instance of
800801
:class:`~mongoengine.queryset.base.BaseQuerySet`).
801802
"""
@@ -825,7 +826,6 @@ def _clone_into(self, new_qs):
825826
"_empty",
826827
"_hint",
827828
"_collation",
828-
"_auto_dereference",
829829
"_search_text",
830830
"_max_time_ms",
831831
"_comment",
@@ -836,6 +836,8 @@ def _clone_into(self, new_qs):
836836
val = getattr(self, prop)
837837
setattr(new_qs, prop, copy.copy(val))
838838

839+
new_qs.__auto_dereference = self._BaseQuerySet__auto_dereference
840+
839841
if self._cursor_obj:
840842
new_qs._cursor_obj = self._cursor_obj.clone()
841843

@@ -1741,10 +1743,15 @@ def _dereference(self):
17411743
self.__dereference = _import_class("DeReference")()
17421744
return self.__dereference
17431745

1746+
@property
1747+
def _auto_dereference(self):
1748+
should_deref = not no_dereferencing_active_for_class(self._document)
1749+
return should_deref and self.__auto_dereference
1750+
17441751
def no_dereference(self):
17451752
"""Turn off any dereferencing for the results of this queryset."""
17461753
queryset = self.clone()
1747-
queryset._auto_dereference = False
1754+
queryset.__auto_dereference = False
17481755
return queryset
17491756

17501757
# Helper Functions

tests/test_context_managers.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22

33
import pytest
4+
from bson import DBRef
45

56
from mongoengine import *
67
from mongoengine.connection import get_db
@@ -19,8 +20,6 @@
1920

2021
class TestContextManagers(MongoDBTestCase):
2122
def test_set_write_concern(self):
22-
connect("mongoenginetest")
23-
2423
class User(Document):
2524
name = StringField()
2625

@@ -39,8 +38,6 @@ class User(Document):
3938
assert original_write_concern.document == collection.write_concern.document
4039

4140
def test_set_read_write_concern(self):
42-
connect("mongoenginetest")
43-
4441
class User(Document):
4542
name = StringField()
4643

@@ -65,7 +62,6 @@ class User(Document):
6562
assert original_write_concern.document == collection.write_concern.document
6663

6764
def test_switch_db_context_manager(self):
68-
connect("mongoenginetest")
6965
register_connection("testdb-1", "mongoenginetest2")
7066

7167
class Group(Document):
@@ -89,7 +85,6 @@ class Group(Document):
8985
assert 1 == Group.objects.count()
9086

9187
def test_switch_collection_context_manager(self):
92-
connect("mongoenginetest")
9388
register_connection(alias="testdb-1", db="mongoenginetest2")
9489

9590
class Group(Document):
@@ -117,7 +112,6 @@ class Group(Document):
117112

118113
def test_no_dereference_context_manager_object_id(self):
119114
"""Ensure that DBRef items in ListFields aren't dereferenced."""
120-
connect("mongoenginetest")
121115

122116
class User(Document):
123117
name = StringField()
@@ -136,25 +130,57 @@ class Group(Document):
136130
user = User.objects.first()
137131
Group(ref=user, members=User.objects, generic=user).save()
138132

139-
with no_dereference(Group) as NoDeRefGroup:
140-
assert Group._fields["members"]._auto_dereference
141-
assert not NoDeRefGroup._fields["members"]._auto_dereference
133+
with no_dereference(Group):
134+
assert not Group._fields["members"]._auto_dereference
142135

143-
with no_dereference(Group) as Group:
136+
with no_dereference(Group):
144137
group = Group.objects.first()
145138
for m in group.members:
146-
assert not isinstance(m, User)
147-
assert not isinstance(group.ref, User)
148-
assert not isinstance(group.generic, User)
139+
assert isinstance(m, DBRef)
140+
assert isinstance(group.ref, DBRef)
141+
assert isinstance(group.generic, dict)
149142

143+
group = Group.objects.first()
150144
for m in group.members:
151145
assert isinstance(m, User)
152146
assert isinstance(group.ref, User)
153147
assert isinstance(group.generic, User)
154148

155-
def test_no_dereference_context_manager_dbref(self):
149+
def test_no_dereference_context_manager_nested(self):
156150
"""Ensure that DBRef items in ListFields aren't dereferenced."""
157-
connect("mongoenginetest")
151+
152+
class User(Document):
153+
name = StringField()
154+
155+
class Group(Document):
156+
ref = ReferenceField(User, dbref=False)
157+
158+
User.drop_collection()
159+
Group.drop_collection()
160+
161+
for i in range(1, 51):
162+
User(name="user %s" % i).save()
163+
164+
user = User.objects.first()
165+
Group(ref=user).save()
166+
167+
with no_dereference(Group):
168+
group = Group.objects.first()
169+
assert isinstance(group.ref, DBRef)
170+
171+
with no_dereference(Group):
172+
group = Group.objects.first()
173+
assert isinstance(group.ref, DBRef)
174+
175+
# make sure its still off here
176+
group = Group.objects.first()
177+
assert isinstance(group.ref, DBRef)
178+
179+
group = Group.objects.first()
180+
assert isinstance(group.ref, User)
181+
182+
def test_no_dereference_context_manager_dbref(self):
183+
"""Ensure that DBRef items in ListFields aren't dereferenced"""
158184

159185
class User(Document):
160186
name = StringField()
@@ -173,16 +199,19 @@ class Group(Document):
173199
user = User.objects.first()
174200
Group(ref=user, members=User.objects, generic=user).save()
175201

176-
with no_dereference(Group) as NoDeRefGroup:
177-
assert Group._fields["members"]._auto_dereference
178-
assert not NoDeRefGroup._fields["members"]._auto_dereference
202+
with no_dereference(Group):
203+
assert not Group._fields["members"]._auto_dereference
179204

180-
with no_dereference(Group) as Group:
181-
group = Group.objects.first()
205+
with no_dereference(Group):
206+
qs = Group.objects
207+
assert qs._auto_dereference is False
208+
group = qs.first()
209+
assert not group._fields["members"]._auto_dereference
182210
assert all(not isinstance(m, User) for m in group.members)
183211
assert not isinstance(group.ref, User)
184212
assert not isinstance(group.generic, User)
185213

214+
group = Group.objects.first()
186215
assert all(isinstance(m, User) for m in group.members)
187216
assert isinstance(group.ref, User)
188217
assert isinstance(group.generic, User)
@@ -265,7 +294,6 @@ def test_query_counter_does_not_swallow_exception(self):
265294
raise TypeError()
266295

267296
def test_query_counter_temporarily_modifies_profiling_level(self):
268-
connect("mongoenginetest")
269297
db = get_db()
270298

271299
def _current_profiling_level():
@@ -290,7 +318,6 @@ def _set_profiling_level(lvl):
290318
raise
291319

292320
def test_query_counter(self):
293-
connect("mongoenginetest")
294321
db = get_db()
295322

296323
collection = db.query_counter
@@ -380,7 +407,6 @@ class B(Document):
380407
assert q == 3
381408

382409
def test_query_counter_counts_getmore_queries(self):
383-
connect("mongoenginetest")
384410
db = get_db()
385411

386412
collection = db.query_counter
@@ -397,7 +423,6 @@ def test_query_counter_counts_getmore_queries(self):
397423
assert q == 2 # 1st select + 1 getmore
398424

399425
def test_query_counter_ignores_particular_queries(self):
400-
connect("mongoenginetest")
401426
db = get_db()
402427

403428
collection = db.query_counter

0 commit comments

Comments
 (0)