Skip to content

Commit a9a9271

Browse files
committed
Improve weakref.proxy implementation for embedded documents
1 parent 8c364e9 commit a9a9271

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

mongoengine/base/datastructures.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def __init__(self, list_items, instance, name):
117117
self._instance = instance
118118
else:
119119
self._instance = weakref.proxy(instance)
120+
120121
self._name = name
121122
super().__init__(list_items)
122123

@@ -189,13 +190,6 @@ def _mark_as_changed(self, key=None):
189190

190191

191192
class EmbeddedDocumentList(BaseList):
192-
def __init__(self, list_items, instance, name):
193-
super().__init__(list_items, instance, name)
194-
if isinstance(instance, weakref.ProxyTypes):
195-
self._instance = instance
196-
else:
197-
self._instance = weakref.proxy(instance)
198-
199193
@classmethod
200194
def __match_all(cls, embedded_doc, kwargs):
201195
"""Return True if a given embedded doc matches all the filter

tests/fields/test_embedded_document_field.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import weakref
12
from copy import deepcopy
23

34
import pytest
@@ -62,6 +63,21 @@ class MyFailingDoc(Document):
6263
class MyFailingdoc2(Document):
6364
emb = EmbeddedDocumentField("MyDoc")
6465

66+
def test_embedded_document_field_has_a_weakref__instance_reference(self):
67+
class Wallet(EmbeddedDocument):
68+
money = IntField()
69+
70+
class WalletOwner(Document):
71+
name = StringField()
72+
wallet = EmbeddedDocumentField(Wallet)
73+
74+
WalletOwner.drop_collection()
75+
76+
wallet = Wallet(money=100)
77+
owner = WalletOwner(name="John", wallet=wallet)
78+
assert wallet._instance is owner
79+
assert isinstance(wallet._instance, weakref.ProxyTypes)
80+
6581
def test_embedded_document_field_validate_subclass(self):
6682
class BaseItem(EmbeddedDocument):
6783
f = IntField()

0 commit comments

Comments
 (0)