Skip to content

Commit 6b18011

Browse files
authored
Merge pull request #429 from MongoEngine/rewrite-connection-module
Rewrite connection module
2 parents 89eb76e + fb83fd6 commit 6b18011

File tree

5 files changed

+108
-135
lines changed

5 files changed

+108
-135
lines changed

flask_mongoengine/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def current_mongoengine_instance():
1818
return k
1919

2020

21-
class MongoEngine(object):
21+
class MongoEngine:
2222
"""Main class used for initialization of Flask-MongoEngine."""
2323

2424
def __init__(self, app=None, config=None):
@@ -110,7 +110,7 @@ def init_app(self, app, config=None):
110110
app.extensions["mongoengine"][self] = s
111111

112112
@property
113-
def connection(self):
113+
def connection(self) -> dict:
114114
"""
115115
Return MongoDB connection(s) associated with this MongoEngine
116116
instance.

flask_mongoengine/connection.py

Lines changed: 79 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,138 +1,117 @@
1+
from typing import List
2+
13
import mongoengine
2-
from pymongo import ReadPreference, uri_parser
34

45
__all__ = (
56
"create_connections",
67
"get_connection_settings",
7-
"InvalidSettingsError",
88
)
99

1010

11-
MONGODB_CONF_VARS = (
12-
"MONGODB_ALIAS",
13-
"MONGODB_DB",
14-
"MONGODB_HOST",
15-
"MONGODB_IS_MOCK",
16-
"MONGODB_PASSWORD",
17-
"MONGODB_PORT",
18-
"MONGODB_USERNAME",
19-
"MONGODB_CONNECT",
20-
"MONGODB_TZ_AWARE",
21-
)
22-
11+
def _get_name(setting_name: str) -> str:
12+
"""
13+
Return known pymongo setting name, or lower case name for unknown.
2314
24-
class InvalidSettingsError(Exception):
25-
pass
15+
This problem discovered in issue #451. As mentioned there pymongo settings are not
16+
case-sensitive, but mongoengine use exact name of some settings for matching,
17+
overwriting pymongo behaviour.
2618
19+
This function address this issue, and potentially address cases when pymongo will
20+
become case-sensitive in some settings by same reasons as mongoengine done.
2721
28-
def _sanitize_settings(settings):
29-
"""Given a dict of connection settings, sanitize the keys and fall
30-
back to some sane defaults.
22+
Based on pymongo 4.1.1 settings.
3123
"""
32-
# Remove the "MONGODB_" prefix and make all settings keys lower case.
24+
KNOWN_CAMEL_CASE_SETTINGS = {
25+
"directconnection": "directConnection",
26+
"maxpoolsize": "maxPoolSize",
27+
"minpoolsize": "minPoolSize",
28+
"maxidletimems": "maxIdleTimeMS",
29+
"maxconnecting": "maxConnecting",
30+
"sockettimeoutms": "socketTimeoutMS",
31+
"connecttimeoutms": "connectTimeoutMS",
32+
"serverselectiontimeoutms": "serverSelectionTimeoutMS",
33+
"waitqueuetimeoutms": "waitQueueTimeoutMS",
34+
"heartbeatfrequencyms": "heartbeatFrequencyMS",
35+
"retrywrites": "retryWrites",
36+
"retryreads": "retryReads",
37+
"zlibcompressionlevel": "zlibCompressionLevel",
38+
"uuidrepresentation": "uuidRepresentation",
39+
"srvservicename": "srvServiceName",
40+
"wtimeoutms": "wTimeoutMS",
41+
"replicaset": "replicaSet",
42+
"readpreference": "readPreference",
43+
"readpreferencetags": "readPreferenceTags",
44+
"maxstalenessseconds": "maxStalenessSeconds",
45+
"authsource": "authSource",
46+
"authmechanism": "authMechanism",
47+
"authmechanismproperties": "authMechanismProperties",
48+
"tlsinsecure": "tlsInsecure",
49+
"tlsallowinvalidcertificates": "tlsAllowInvalidCertificates",
50+
"tlsallowinvalidhostnames": "tlsAllowInvalidHostnames",
51+
"tlscafile": "tlsCAFile",
52+
"tlscertificatekeyfile": "tlsCertificateKeyFile",
53+
"tlscrlfile": "tlsCRLFile",
54+
"tlscertificatekeyfilepassword": "tlsCertificateKeyFilePassword",
55+
"tlsdisableocspendpointcheck": "tlsDisableOCSPEndpointCheck",
56+
"readconcernlevel": "readConcernLevel",
57+
}
58+
_setting_name = KNOWN_CAMEL_CASE_SETTINGS.get(setting_name.lower())
59+
return setting_name.lower() if _setting_name is None else _setting_name
60+
61+
62+
def _sanitize_settings(settings: dict) -> dict:
63+
"""Remove MONGODB_ prefix from dict values, to correct bypass to mongoengine."""
3364
resolved_settings = {}
3465
for k, v in settings.items():
35-
if k.startswith("MONGODB_"):
36-
k = k[len("MONGODB_") :]
37-
k = k.lower()
38-
resolved_settings[k] = v
39-
40-
# Handle uri style connections
41-
if "://" in resolved_settings.get("host", ""):
42-
# this section pulls the database name from the URI
43-
# PyMongo requires URI to start with mongodb:// to parse
44-
# this workaround allows mongomock to work
45-
uri_to_check = resolved_settings["host"]
46-
47-
if uri_to_check.startswith("mongomock://"):
48-
uri_to_check = uri_to_check.replace("mongomock://", "mongodb://")
49-
50-
uri_dict = uri_parser.parse_uri(uri_to_check)
51-
resolved_settings["db"] = uri_dict["database"]
52-
53-
# Add a default name param or use the "db" key if exists
54-
if resolved_settings.get("db"):
55-
resolved_settings["name"] = resolved_settings.pop("db")
56-
else:
57-
resolved_settings["name"] = "test"
58-
59-
# Add various default values.
60-
resolved_settings["alias"] = resolved_settings.get(
61-
"alias", mongoengine.DEFAULT_CONNECTION_NAME
62-
)
63-
# TODO do we have to specify it here? MongoEngine should take care of that
64-
resolved_settings["host"] = resolved_settings.get("host", "localhost")
65-
# TODO this is the default host in pymongo.mongo_client.MongoClient, we may
66-
# not need to explicitly set a default here
67-
resolved_settings["port"] = resolved_settings.get("port", 27017)
68-
# TODO this is the default port in pymongo.mongo_client.MongoClient, we may
69-
# not need to explicitly set a default here
70-
71-
# Default to ReadPreference.PRIMARY if no read_preference is supplied
72-
resolved_settings["read_preference"] = resolved_settings.get(
73-
"read_preference", ReadPreference.PRIMARY
74-
)
75-
76-
# Clean up empty values
77-
for k, v in list(resolved_settings.items()):
78-
if v is None:
79-
del resolved_settings[k]
66+
# Replace with k.lower().removeprefix("mongodb_") when python 3.8 support ends.
67+
key = _get_name(k[8:]) if k.lower().startswith("mongodb_") else _get_name(k)
68+
resolved_settings[key] = v
8069

8170
return resolved_settings
8271

8372

84-
def get_connection_settings(config):
73+
def get_connection_settings(config: dict) -> List[dict]:
8574
"""
8675
Given a config dict, return a sanitized dict of MongoDB connection
8776
settings that we can then use to establish connections. For new
8877
applications, settings should exist in a ``MONGODB_SETTINGS`` key, but
8978
for backward compatibility we also support several config keys
9079
prefixed by ``MONGODB_``, e.g. ``MONGODB_HOST``, ``MONGODB_PORT``, etc.
9180
"""
81+
82+
# If no "MONGODB_SETTINGS", sanitize the "MONGODB_" keys as single connection.
83+
if "MONGODB_SETTINGS" not in config:
84+
config = {k: v for k, v in config.items() if k.lower().startswith("mongodb_")}
85+
return [_sanitize_settings(config)]
86+
9287
# Sanitize all the settings living under a "MONGODB_SETTINGS" config var
93-
if "MONGODB_SETTINGS" in config:
94-
settings = config["MONGODB_SETTINGS"]
88+
settings = config["MONGODB_SETTINGS"]
9589

96-
# If MONGODB_SETTINGS is a list of settings dicts, sanitize each
97-
# dict separately.
98-
if isinstance(settings, list):
99-
return [_sanitize_settings(setting) for setting in settings]
100-
else:
101-
return _sanitize_settings(settings)
90+
# If MONGODB_SETTINGS is a list of settings dicts, sanitize each dict separately.
91+
if isinstance(settings, list):
92+
return [_sanitize_settings(settings_dict) for settings_dict in settings]
10293

103-
else:
104-
config = {k: v for k, v in config.items() if k in MONGODB_CONF_VARS}
105-
return _sanitize_settings(config)
94+
# Otherwise, it should be a single dict describing a single connection.
95+
return [_sanitize_settings(settings)]
10696

10797

108-
def create_connections(config):
98+
def create_connections(config: dict):
10999
"""
110100
Given Flask application's config dict, extract relevant config vars
111101
out of it and establish MongoEngine connection(s) based on them.
112102
"""
113-
# Validate that the config is a dict
114-
if config is None or not isinstance(config, dict):
115-
raise InvalidSettingsError("Invalid application configuration")
103+
# Validate that the config is a dict and dict is not empty
104+
if not config or not isinstance(config, dict):
105+
raise TypeError(f"Config dictionary expected, but {type(config)} received.")
116106

117107
# Get sanitized connection settings based on the config
118-
conn_settings = get_connection_settings(config)
119-
120-
# If conn_settings is a list, set up each item as a separate connection
121-
# and return a dict of connection aliases and their connections.
122-
if isinstance(conn_settings, list):
123-
connections = {}
124-
for each in conn_settings:
125-
alias = each["alias"]
126-
connections[alias] = _connect(each)
127-
return connections
108+
connection_settings = get_connection_settings(config)
128109

129-
# Otherwise, return a single connection
130-
return _connect(conn_settings)
110+
connections = {}
111+
for connection_setting in connection_settings:
112+
alias = connection_setting.setdefault(
113+
"alias", mongoengine.DEFAULT_CONNECTION_NAME
114+
)
115+
connections[alias] = mongoengine.connect(**connection_setting)
131116

132-
133-
def _connect(conn_settings):
134-
"""Given a dict of connection settings, create a connection to
135-
MongoDB by calling {func}`mongoengine.connect` and return its result.
136-
"""
137-
db_name = conn_settings.pop("name")
138-
return mongoengine.connect(db_name, **conn_settings)
117+
return connections

tests/conftest.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,22 @@ def app():
3939
def db(app):
4040
app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db"
4141
test_db = MongoEngine(app)
42-
db_name = test_db.connection.get_database("flask_mongoengine_test_db").name
42+
db_name = (
43+
test_db.connection["default"].get_database("flask_mongoengine_test_db").name
44+
)
4345

4446
if not db_name.endswith("_test_db"):
4547
raise RuntimeError(
4648
f"DATABASE_URL must point to testing db, not to master db ({db_name})"
4749
)
4850

4951
# Clear database before tests, for cases when some test failed before.
50-
test_db.connection.drop_database(db_name)
52+
test_db.connection["default"].drop_database(db_name)
5153

5254
yield test_db
5355

5456
# Clear database after tests, for graceful exit.
55-
test_db.connection.drop_database(db_name)
57+
test_db.connection["default"].drop_database(db_name)
5658

5759

5860
@pytest.fixture()

tests/test_connection.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import mongoengine
2-
import pymongo
32
import pytest
43
from mongoengine.connection import ConnectionFailure
54
from mongoengine.context_managers import switch_db
@@ -11,6 +10,14 @@
1110
from flask_mongoengine import MongoEngine, current_mongoengine_instance
1211

1312

13+
def is_mongo_mock_installed() -> bool:
14+
try:
15+
import mongomock.__version__ # noqa
16+
except ImportError:
17+
return False
18+
return True
19+
20+
1421
def test_connection__should_use_defaults__if_no_settings_provided(app):
1522
"""Make sure a simple connection to a standalone MongoDB works."""
1623
db = MongoEngine()
@@ -129,6 +136,9 @@ def test_connection__should_parse_host_uri__if_host_formatted_as_uri(
129136
assert connection.PORT == 27017
130137

131138

139+
@pytest.mark.skipif(
140+
is_mongo_mock_installed(), reason="This test require mongomock not exist"
141+
)
132142
@pytest.mark.parametrize(
133143
("config_extension"),
134144
[
@@ -281,46 +291,26 @@ class Todo(db.Document):
281291
assert doc is not None
282292

283293

284-
def test_ignored_mongodb_prefix_config(app):
285-
"""Config starting by MONGODB_ but not used by flask-mongoengine
286-
should be ignored.
287-
"""
294+
def test_incorrect_value_with_mongodb_prefix__should_trigger_mongoengine_raise(app):
288295
db = MongoEngine()
289296
app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db"
290297
# Invalid host, should trigger exception if used
291298
app.config["MONGODB_TEST_HOST"] = "dummy://localhost:27017/test"
292-
db.init_app(app)
293-
294-
connection = mongoengine.get_connection()
295-
mongo_engine_db = mongoengine.get_db()
296-
assert isinstance(mongo_engine_db, Database)
297-
assert isinstance(connection, MongoClient)
298-
assert mongo_engine_db.name == "flask_mongoengine_test_db"
299-
assert connection.HOST == "localhost"
300-
assert connection.PORT == 27017
299+
with pytest.raises(ConnectionFailure):
300+
db.init_app(app)
301301

302302

303303
def test_connection_kwargs(app):
304304
"""Make sure additional connection kwargs work."""
305305

306-
# Figure out whether to use "MAX_POOL_SIZE" or "MAXPOOLSIZE" based
307-
# on PyMongo version (former was changed to the latter as described
308-
# in https://jira.mongodb.org/browse/PYTHON-854)
309-
# TODO remove once PyMongo < 3.0 support is dropped
310-
if pymongo.version_tuple[0] >= 3:
311-
MAX_POOL_SIZE_KEY = "MAXPOOLSIZE"
312-
else:
313-
MAX_POOL_SIZE_KEY = "MAX_POOL_SIZE"
314-
315306
app.config["MONGODB_SETTINGS"] = {
316307
"ALIAS": "tz_aware_true",
317308
"DB": "flask_mongoengine_test_db",
318309
"TZ_AWARE": True,
319310
"READ_PREFERENCE": ReadPreference.SECONDARY,
320-
MAX_POOL_SIZE_KEY: 10,
311+
"MAXPOOLSIZE": 10,
321312
}
322313
db = MongoEngine(app)
323314

324-
assert db.connection.codec_options.tz_aware
325-
# assert db.connection.max_pool_size == 10
326-
assert db.connection.read_preference == ReadPreference.SECONDARY
315+
assert db.connection["tz_aware_true"].codec_options.tz_aware
316+
assert db.connection["tz_aware_true"].read_preference == ReadPreference.SECONDARY

tests/test_json.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,22 @@ def extended_db(app):
99
app.json_encoder = DummyEncoder
1010
app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db"
1111
test_db = MongoEngine(app)
12-
db_name = test_db.connection.get_database("flask_mongoengine_test_db").name
12+
db_name = (
13+
test_db.connection["default"].get_database("flask_mongoengine_test_db").name
14+
)
1315

1416
if not db_name.endswith("_test_db"):
1517
raise RuntimeError(
1618
f"DATABASE_URL must point to testing db, not to master db ({db_name})"
1719
)
1820

1921
# Clear database before tests, for cases when some test failed before.
20-
test_db.connection.drop_database(db_name)
22+
test_db.connection["default"].drop_database(db_name)
2123

2224
yield test_db
2325

2426
# Clear database after tests, for graceful exit.
25-
test_db.connection.drop_database(db_name)
27+
test_db.connection["default"].drop_database(db_name)
2628

2729

2830
class DummyEncoder(flask.json.JSONEncoder):

0 commit comments

Comments
 (0)