Skip to content

Commit eddce8e

Browse files
committed
Fixing loading of EnumField inside ListField
Under certain condition, inside value of ListField(EnumField) don't get casted to the proper Enum. This patches fixes that.
1 parent 1790f3d commit eddce8e

File tree

3 files changed

+79
-16
lines changed

3 files changed

+79
-16
lines changed

mongoengine/base/fields.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,7 @@ def __set__(self, instance, value):
134134
# If setting to None and there is a default value provided for this
135135
# field, then set the value to the default value.
136136
if value is None:
137-
if self.null:
138-
value = None
139-
elif self.default is not None:
137+
if not self.null and self.default is not None:
140138
value = self.default
141139
if callable(value):
142140
value = value()
@@ -282,6 +280,19 @@ def _lazy_load_refs(instance, name, ref_values, *, max_depth):
282280
)
283281
return documents
284282

283+
def __set__(self, instance, value):
284+
unprocessable_fields = (
285+
ComplexBaseField,
286+
_import_class("EmbeddedDocumentField"),
287+
_import_class("FileField"),
288+
)
289+
if self.field and not isinstance(self.field, unprocessable_fields):
290+
if isinstance(value, (list, tuple)):
291+
value = [self.field.to_python(sub_val) for sub_val in value]
292+
elif isinstance(value, dict):
293+
value = {key: self.field.to_python(sub) for key, sub in value.items()}
294+
return super().__set__(instance, value)
295+
285296
def __get__(self, instance, owner):
286297
"""Descriptor to automatically dereference references."""
287298
if instance is None:
@@ -439,7 +450,7 @@ def to_mongo(self, value, use_db_field=True, fields=None):
439450
# us to dereference
440451
meta = getattr(v, "_meta", {})
441452
allow_inheritance = meta.get("allow_inheritance")
442-
if not allow_inheritance and not self.field:
453+
if not allow_inheritance:
443454
value_dict[k] = GenericReferenceField().to_mongo(v)
444455
else:
445456
collection = v._get_collection_name()

mongoengine/fields.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,14 +1660,25 @@ def __init__(self, enum, **kwargs):
16601660
kwargs["choices"] = list(self._enum_cls) # Implicit validator
16611661
super().__init__(**kwargs)
16621662

1663-
def __set__(self, instance, value):
1664-
is_legal_value = value is None or isinstance(value, self._enum_cls)
1665-
if not is_legal_value:
1663+
def validate(self, value):
1664+
if isinstance(value, self._enum_cls):
1665+
return super().validate(value)
1666+
try:
1667+
self._enum_cls(value)
1668+
except ValueError:
1669+
self.error(f"{value} is not a valid {self._enum_cls}")
1670+
1671+
def to_python(self, value):
1672+
value = super().to_python(value)
1673+
if not isinstance(value, self._enum_cls):
16661674
try:
1667-
value = self._enum_cls(value)
1668-
except Exception:
1669-
pass
1670-
return super().__set__(instance, value)
1675+
return self._enum_cls(value)
1676+
except ValueError:
1677+
return value
1678+
return value
1679+
1680+
def __set__(self, instance, value):
1681+
return super().__set__(instance, self.to_python(value))
16711682

16721683
def to_mongo(self, value):
16731684
if isinstance(value, self._enum_cls):

tests/fields/test_enum_field.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import pytest
44
from bson import InvalidDocument
55

6-
from mongoengine import Document, EnumField, ValidationError
6+
from mongoengine import (
7+
DictField,
8+
Document,
9+
EnumField,
10+
ListField,
11+
ValidationError,
12+
)
713
from tests.utils import MongoDBTestCase, get_as_pymongo
814

915

@@ -21,6 +27,12 @@ class ModelWithEnum(Document):
2127
status = EnumField(Status)
2228

2329

30+
class ModelComplexEnum(Document):
31+
status = EnumField(Status)
32+
statuses = ListField(EnumField(Status))
33+
color_mapping = DictField(EnumField(Color))
34+
35+
2436
class TestStringEnumField(MongoDBTestCase):
2537
def test_storage(self):
2638
model = ModelWithEnum(status=Status.NEW).save()
@@ -101,6 +113,38 @@ def test_wrong_choices(self):
101113
with pytest.raises(ValueError, match="Invalid choices"):
102114
EnumField(Status, choices=[Status.DONE, Color.RED])
103115

116+
def test_embedding_in_complex_field(self):
117+
ModelComplexEnum.drop_collection()
118+
model = ModelComplexEnum(
119+
status="new", statuses=["new"], color_mapping={"red": 1}
120+
).save()
121+
assert model.status == Status.NEW
122+
assert model.statuses == [Status.NEW]
123+
assert model.color_mapping == {"red": Color.RED}
124+
model.reload()
125+
assert model.status == Status.NEW
126+
assert model.statuses == [Status.NEW]
127+
assert model.color_mapping == {"red": Color.RED}
128+
model.status = "done"
129+
model.color_mapping = {"blue": 2}
130+
model.statuses = ["new", "done"]
131+
assert model.status == Status.DONE
132+
assert model.color_mapping == {"blue": Color.BLUE}, model.color_mapping
133+
assert model.statuses == [Status.NEW, Status.DONE], model.statuses
134+
model = model.save().reload()
135+
assert model.status == Status.DONE
136+
assert model.color_mapping == {"blue": Color.BLUE}, model.color_mapping
137+
assert model.statuses == [Status.NEW, Status.DONE], model.statuses
138+
139+
with pytest.raises(ValidationError, match="must be one of ..Status"):
140+
model.statuses = [1]
141+
model.save()
142+
143+
model.statuses = ["done"]
144+
model.color_mapping = {"blue": "done"}
145+
with pytest.raises(ValidationError, match="must be one of ..Color"):
146+
model.save()
147+
104148

105149
class ModelWithColor(Document):
106150
color = EnumField(Color, default=Color.RED)
@@ -124,10 +168,7 @@ def test_storage_enum_with_int(self):
124168
assert get_as_pymongo(model) == {"_id": model.id, "color": 2}
125169

126170
def test_validate_model(self):
127-
with pytest.raises(ValidationError, match="Value must be one of"):
128-
ModelWithColor(color=3).validate()
129-
130-
with pytest.raises(ValidationError, match="Value must be one of"):
171+
with pytest.raises(ValidationError, match="must be one of ..Color"):
131172
ModelWithColor(color="wrong_type").validate()
132173

133174

0 commit comments

Comments
 (0)