Skip to content

Commit e8abddf

Browse files
authored
Merge pull request #489 from MongoEngine/orm_module_simplification
Updated: ORM module functions split for better maintenance
2 parents f8af880 + 9a4adbd commit e8abddf

File tree

6 files changed

+146
-73
lines changed

6 files changed

+146
-73
lines changed

flask_mongoengine/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,12 @@
55
from mongoengine.errors import DoesNotExist
66
from mongoengine.queryset import QuerySet
77

8+
from flask_mongoengine import db_fields
89
from flask_mongoengine.connection import *
910
from flask_mongoengine.json import override_json_encoder
1011
from flask_mongoengine.pagination import *
1112
from flask_mongoengine.sessions import *
1213

13-
try:
14-
from flask_mongoengine.wtf import db_fields
15-
except ImportError:
16-
from mongoengine import fields as db_fields
17-
1814

1915
def current_mongoengine_instance():
2016
"""Return a MongoEngine instance associated with current Flask app."""

flask_mongoengine/wtf/db_fields.py renamed to flask_mongoengine/db_fields.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
"""Responsible for mongoengine fields extension, if WTFForms integration used."""
2-
from mongoengine import fields
3-
42
__all__ = [
53
"WtfFieldMixin",
64
"BinaryField",
@@ -43,6 +41,16 @@
4341
"URLField",
4442
"UUIDField",
4543
]
44+
from typing import Callable, List, Optional, Union
45+
46+
from mongoengine import fields
47+
48+
try:
49+
from wtforms import fields as wtf_fields
50+
from wtforms import validators as wtf_validators
51+
except ImportError: # pragma: no cover
52+
wtf_fields = None
53+
wtf_validators = None
4654

4755

4856
class WtfFieldMixin:
@@ -58,28 +66,35 @@ class WtfFieldMixin:
5866
:param kwargs: keyword arguments silently bypassed to normal mongoengine fields
5967
"""
6068

61-
def __init__(self, *, validators=None, filters=None, **kwargs):
69+
def __init__(
70+
self,
71+
*,
72+
validators: Optional[Union[List, Callable]] = None,
73+
filters: Optional[Union[List, Callable]] = None,
74+
**kwargs,
75+
):
6276
self.validators = self._ensure_callable_or_list(validators, "validators")
6377
self.filters = self._ensure_callable_or_list(filters, "filters")
6478

6579
super().__init__(**kwargs)
6680

67-
def _ensure_callable_or_list(self, field, msg_flag):
81+
@staticmethod
82+
def _ensure_callable_or_list(argument, msg_flag: str) -> Optional[List]:
6883
"""
69-
Ensure the value submitted via field is either
70-
a callable object to convert to list or it is
71-
in fact a valid list value.
84+
Ensure submitted argument value is a callable object or valid list value.
7285
86+
:param argument: Argument input to make verification on.
87+
:param msg_flag: Argument string name for error message.
7388
"""
74-
if field is not None:
75-
if callable(field):
76-
field = [field]
77-
else:
78-
msg = "Argument '%s' must be a list value" % msg_flag
79-
if not isinstance(field, list):
80-
raise TypeError(msg)
81-
82-
return field
89+
if argument is None:
90+
return None
91+
92+
if callable(argument):
93+
return [argument]
94+
elif not isinstance(argument, list):
95+
raise TypeError(f"Argument '{msg_flag}' must be a list value")
96+
97+
return argument
8398

8499

85100
class BinaryField(WtfFieldMixin, fields.BinaryField):

flask_mongoengine/decorators.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import functools
2+
import logging
3+
4+
try:
5+
import wtforms # noqa
6+
7+
wtf_installed = True
8+
except ImportError: # pragma: no cover
9+
wtf_installed = False
10+
11+
logger = logging.getLogger("flask_mongoengine")
12+
13+
14+
def wtf_required(func):
15+
"""Special decorator to warn user on incorrect installation."""
16+
17+
@functools.wraps(func)
18+
def wrapped(*args, **kwargs):
19+
if not wtf_installed:
20+
logger.error(f"WTForms not installed. Function '{func.__name__}' aborted.")
21+
return None
22+
23+
return func(*args, **kwargs)
24+
25+
return wrapped

flask_mongoengine/wtf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""WTFForms integration module init file."""
2-
from flask_mongoengine.wtf.db_fields import * # noqa
2+
from flask_mongoengine.wtf.orm import model_fields, model_form # noqa

flask_mongoengine/wtf/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
from typing import Type, Union
2+
3+
import mongoengine
14
from flask_wtf import FlaskForm
25
from flask_wtf.form import _Auto
36

47

58
class ModelForm(FlaskForm):
69
"""A WTForms mongoengine model form"""
710

11+
model_class: Type[Union[mongoengine.Document, mongoengine.DynamicDocument]]
12+
813
def __init__(self, formdata=_Auto, **kwargs):
914
self.instance = kwargs.pop("instance", None) or kwargs.get("obj")
1015
if self.instance and not formdata:

flask_mongoengine/wtf/orm.py

Lines changed: 83 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
"""
44
import decimal
55
from collections import OrderedDict
6+
from typing import List, Optional, Type
67

78
from bson import ObjectId
89
from mongoengine import ReferenceField
10+
from mongoengine.base import BaseDocument, DocumentMetaclass
911
from wtforms import fields as f
1012
from wtforms import validators
1113

@@ -45,8 +47,8 @@ def __init__(self, converters=None):
4547

4648
self.converters = converters
4749

48-
def convert(self, model, field, field_args):
49-
kwargs = {
50+
def _generate_convert_base_kwargs(self, field, field_args) -> dict:
51+
kwargs: dict = {
5052
"label": getattr(field, "verbose_name", field.name),
5153
"description": getattr(field, "help_text", None) or "",
5254
"validators": getattr(field, "validators", None) or [],
@@ -56,42 +58,49 @@ def convert(self, model, field, field_args):
5658
if field_args:
5759
kwargs.update(field_args)
5860

59-
if kwargs["validators"]:
60-
# Create a copy of the list since we will be modifying it, and if
61-
# validators set as shared list between fields - duplicates/conflicts may
62-
# be created.
63-
kwargs["validators"] = list(kwargs["validators"])
64-
61+
# Create a copy of the lists since we will be modifying it, and if
62+
# validators set as shared list between fields - duplicates/conflicts may
63+
# be created.
64+
kwargs["validators"] = list(kwargs["validators"])
65+
kwargs["filters"] = list(kwargs["filters"])
6566
if field.required:
6667
kwargs["validators"].append(validators.InputRequired())
6768
else:
6869
kwargs["validators"].append(validators.Optional())
6970

70-
ftype = type(field).__name__
71+
return kwargs
7172

72-
if field.choices:
73-
kwargs["choices"] = field.choices
74-
75-
if ftype in self.converters:
76-
kwargs["coerce"] = self.coerce(ftype)
77-
multiple_field = kwargs.pop("multiple", False)
78-
radio_field = kwargs.pop("radio", False)
79-
if multiple_field:
80-
return f.SelectMultipleField(**kwargs)
81-
if radio_field:
82-
return f.RadioField(**kwargs)
83-
return f.SelectField(**kwargs)
73+
def _process_convert_for_choice_fields(self, field, field_class, kwargs):
74+
kwargs["choices"] = field.choices
75+
kwargs["coerce"] = self.coerce(field_class)
76+
if kwargs.pop("multiple", False):
77+
return f.SelectMultipleField(**kwargs)
78+
if kwargs.pop("radio", False):
79+
return f.RadioField(**kwargs)
80+
return f.SelectField(**kwargs)
8481

82+
def convert(self, model, field, field_args):
8583
if hasattr(field, "to_form_field"):
86-
return field.to_form_field(model, kwargs)
84+
return field.to_form_field(model, field_args)
85+
86+
field_class = type(field).__name__
87+
88+
if field_class not in self.converters:
89+
raise NotImplementedError(
90+
f"No converter for: {field_class}, exclude it from form generation."
91+
)
92+
93+
kwargs = self._generate_convert_base_kwargs(field, field_args)
94+
95+
if field.choices:
96+
return self._process_convert_for_choice_fields(field, field_class, kwargs)
8797

8898
if hasattr(field, "field") and isinstance(field.field, ReferenceField):
8999
kwargs["label_modifier"] = getattr(
90100
model, f"{field.name}_label_modifier", None
91101
)
92102

93-
if ftype in self.converters:
94-
return self.converters[ftype](model, field, kwargs)
103+
return self.converters[field_class](model, field, kwargs)
95104

96105
@classmethod
97106
def _string_common(cls, model, field, kwargs):
@@ -237,46 +246,68 @@ def coerce(self, field_type):
237246
return coercions.get(field_type, str)
238247

239248

240-
def model_fields(model, only=None, exclude=None, field_args=None, converter=None):
249+
def _get_fields_names(
250+
model,
251+
only: Optional[List[str]],
252+
exclude: Optional[List[str]],
253+
) -> List[str]:
241254
"""
242-
Generate a dictionary of fields for a given database model.
255+
Filter fields names for further form generation.
243256
244-
See `model_form` docstring for description of parameters.
257+
:param model: Source model class for fields list retrieval
258+
:param only: If provided, only these field names will have fields definition.
259+
:param exclude: If provided, field names will be excluded from fields definition.
260+
All other field names will have fields.
245261
"""
246-
from mongoengine.base import BaseDocument, DocumentMetaclass
262+
field_names = model._fields_ordered
247263

248-
if not isinstance(model, (BaseDocument, DocumentMetaclass)):
264+
if only:
265+
field_names = [field for field in only if field in field_names]
266+
elif exclude:
267+
field_names = [field for field in field_names if field not in set(exclude)]
268+
269+
return field_names
270+
271+
272+
def model_fields(
273+
model: Type[BaseDocument],
274+
only: Optional[List[str]] = None,
275+
exclude: Optional[List[str]] = None,
276+
field_args=None,
277+
converter=None,
278+
) -> OrderedDict:
279+
"""
280+
Generate a dictionary of fields for a given database model.
281+
282+
See :func:`model_form` docstring for description of parameters.
283+
"""
284+
if not issubclass(model, (BaseDocument, DocumentMetaclass)):
249285
raise TypeError("model must be a mongoengine Document schema")
250286

251287
converter = converter or ModelConverter()
252288
field_args = field_args or {}
289+
form_fields_dict = OrderedDict()
290+
# noinspection PyTypeChecker
291+
fields_names = _get_fields_names(model, only, exclude)
253292

254-
names = ((k, v.creation_counter) for k, v in model._fields.items())
255-
field_names = [n[0] for n in sorted(names, key=lambda n: n[1])]
293+
for field_name in fields_names:
294+
# noinspection PyUnresolvedReferences
295+
model_field = model._fields[field_name]
296+
form_field = converter.convert(model, model_field, field_args.get(field_name))
297+
if form_field is not None:
298+
form_fields_dict[field_name] = form_field
256299

257-
if only:
258-
field_names = [x for x in only if x in set(field_names)]
259-
elif exclude:
260-
field_names = [x for x in field_names if x not in set(exclude)]
261-
262-
field_dict = OrderedDict()
263-
for name in field_names:
264-
model_field = model._fields[name]
265-
field = converter.convert(model, model_field, field_args.get(name))
266-
if field is not None:
267-
field_dict[name] = field
268-
269-
return field_dict
300+
return form_fields_dict
270301

271302

272303
def model_form(
273-
model,
274-
base_class=ModelForm,
275-
only=None,
276-
exclude=None,
304+
model: Type[BaseDocument],
305+
base_class: Type[ModelForm] = ModelForm,
306+
only: Optional[List[str]] = None,
307+
exclude: Optional[List[str]] = None,
277308
field_args=None,
278309
converter=None,
279-
):
310+
) -> Type[ModelForm]:
280311
"""
281312
Create a wtforms Form for a given mongoengine Document schema::
282313
@@ -287,7 +318,7 @@ def model_form(
287318
:param model:
288319
A mongoengine Document schema class
289320
:param base_class:
290-
Base form class to extend from. Must be a ``wtforms.Form`` subclass.
321+
Base form class to extend from. Must be a :class:`.ModelForm` subclass.
291322
:param only:
292323
An optional iterable with the property names that should be included in
293324
the form. Only these properties will have fields.
@@ -299,8 +330,9 @@ def model_form(
299330
to construct each field object.
300331
:param converter:
301332
A converter to generate the fields based on the model properties. If
302-
not set, ``ModelConverter`` is used.
333+
not set, :class:`.ModelConverter` is used.
303334
"""
304335
field_dict = model_fields(model, only, exclude, field_args, converter)
305336
field_dict["model_class"] = model
337+
# noinspection PyTypeChecker
306338
return type(f"{model.__name__}Form", (base_class,), field_dict)

0 commit comments

Comments
 (0)