Skip to content

Commit 2bac070

Browse files
committed
Merge branch 'fix/list-enum-fields' of github.com:jaesivsm/mongoengine into clone_list_enum_fields
2 parents 30c2485 + eddce8e commit 2bac070

File tree

3 files changed

+79
-16
lines changed

3 files changed

+79
-16
lines changed

mongoengine/base/fields.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,7 @@ def __set__(self, instance, value):
134134
# If setting to None and there is a default value provided for this
135135
# field, then set the value to the default value.
136136
if value is None:
137-
if self.null:
138-
value = None
139-
elif self.default is not None:
137+
if not self.null and self.default is not None:
140138
value = self.default
141139
if callable(value):
142140
value = value()
@@ -282,6 +280,19 @@ def _lazy_load_refs(instance, name, ref_values, *, max_depth):
282280
)
283281
return documents
284282

283+
def __set__(self, instance, value):
284+
unprocessable_fields = (
285+
ComplexBaseField,
286+
_import_class("EmbeddedDocumentField"),
287+
_import_class("FileField"),
288+
)
289+
if self.field and not isinstance(self.field, unprocessable_fields):
290+
if isinstance(value, (list, tuple)):
291+
value = [self.field.to_python(sub_val) for sub_val in value]
292+
elif isinstance(value, dict):
293+
value = {key: self.field.to_python(sub) for key, sub in value.items()}
294+
return super().__set__(instance, value)
295+
285296
def __get__(self, instance, owner):
286297
"""Descriptor to automatically dereference references."""
287298
if instance is None:
@@ -439,7 +450,7 @@ def to_mongo(self, value, use_db_field=True, fields=None):
439450
# us to dereference
440451
meta = getattr(v, "_meta", {})
441452
allow_inheritance = meta.get("allow_inheritance")
442-
if not allow_inheritance and not self.field:
453+
if not allow_inheritance:
443454
value_dict[k] = GenericReferenceField().to_mongo(v)
444455
else:
445456
collection = v._get_collection_name()

mongoengine/fields.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,14 +1665,25 @@ def __init__(self, enum, **kwargs):
16651665
kwargs["choices"] = list(self._enum_cls) # Implicit validator
16661666
super().__init__(**kwargs)
16671667

1668-
def __set__(self, instance, value):
1669-
is_legal_value = value is None or isinstance(value, self._enum_cls)
1670-
if not is_legal_value:
1668+
def validate(self, value):
1669+
if isinstance(value, self._enum_cls):
1670+
return super().validate(value)
1671+
try:
1672+
self._enum_cls(value)
1673+
except ValueError:
1674+
self.error(f"{value} is not a valid {self._enum_cls}")
1675+
1676+
def to_python(self, value):
1677+
value = super().to_python(value)
1678+
if not isinstance(value, self._enum_cls):
16711679
try:
1672-
value = self._enum_cls(value)
1673-
except Exception:
1674-
pass
1675-
return super().__set__(instance, value)
1680+
return self._enum_cls(value)
1681+
except ValueError:
1682+
return value
1683+
return value
1684+
1685+
def __set__(self, instance, value):
1686+
return super().__set__(instance, self.to_python(value))
16761687

16771688
def to_mongo(self, value):
16781689
if isinstance(value, self._enum_cls):

tests/fields/test_enum_field.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import pytest
44
from bson import InvalidDocument
55

6-
from mongoengine import Document, EnumField, ValidationError
6+
from mongoengine import (
7+
DictField,
8+
Document,
9+
EnumField,
10+
ListField,
11+
ValidationError,
12+
)
713
from tests.utils import MongoDBTestCase, get_as_pymongo
814

915

@@ -21,6 +27,12 @@ class ModelWithEnum(Document):
2127
status = EnumField(Status)
2228

2329

30+
class ModelComplexEnum(Document):
31+
status = EnumField(Status)
32+
statuses = ListField(EnumField(Status))
33+
color_mapping = DictField(EnumField(Color))
34+
35+
2436
class TestStringEnumField(MongoDBTestCase):
2537
def test_storage(self):
2638
model = ModelWithEnum(status=Status.NEW).save()
@@ -101,6 +113,38 @@ def test_wrong_choices(self):
101113
with pytest.raises(ValueError, match="Invalid choices"):
102114
EnumField(Status, choices=[Status.DONE, Color.RED])
103115

116+
def test_embedding_in_complex_field(self):
117+
ModelComplexEnum.drop_collection()
118+
model = ModelComplexEnum(
119+
status="new", statuses=["new"], color_mapping={"red": 1}
120+
).save()
121+
assert model.status == Status.NEW
122+
assert model.statuses == [Status.NEW]
123+
assert model.color_mapping == {"red": Color.RED}
124+
model.reload()
125+
assert model.status == Status.NEW
126+
assert model.statuses == [Status.NEW]
127+
assert model.color_mapping == {"red": Color.RED}
128+
model.status = "done"
129+
model.color_mapping = {"blue": 2}
130+
model.statuses = ["new", "done"]
131+
assert model.status == Status.DONE
132+
assert model.color_mapping == {"blue": Color.BLUE}, model.color_mapping
133+
assert model.statuses == [Status.NEW, Status.DONE], model.statuses
134+
model = model.save().reload()
135+
assert model.status == Status.DONE
136+
assert model.color_mapping == {"blue": Color.BLUE}, model.color_mapping
137+
assert model.statuses == [Status.NEW, Status.DONE], model.statuses
138+
139+
with pytest.raises(ValidationError, match="must be one of ..Status"):
140+
model.statuses = [1]
141+
model.save()
142+
143+
model.statuses = ["done"]
144+
model.color_mapping = {"blue": "done"}
145+
with pytest.raises(ValidationError, match="must be one of ..Color"):
146+
model.save()
147+
104148

105149
class ModelWithColor(Document):
106150
color = EnumField(Color, default=Color.RED)
@@ -124,10 +168,7 @@ def test_storage_enum_with_int(self):
124168
assert get_as_pymongo(model) == {"_id": model.id, "color": 2}
125169

126170
def test_validate_model(self):
127-
with pytest.raises(ValidationError, match="Value must be one of"):
128-
ModelWithColor(color=3).validate()
129-
130-
with pytest.raises(ValidationError, match="Value must be one of"):
171+
with pytest.raises(ValidationError, match="must be one of ..Color"):
131172
ModelWithColor(color="wrong_type").validate()
132173

133174

0 commit comments

Comments
 (0)