|
| 1 | +from typing import List |
| 2 | + |
1 | 3 | import mongoengine |
2 | | -from pymongo import ReadPreference, uri_parser |
3 | 4 |
|
4 | 5 | __all__ = ( |
5 | 6 | "create_connections", |
6 | 7 | "get_connection_settings", |
7 | | - "InvalidSettingsError", |
8 | 8 | ) |
9 | 9 |
|
10 | 10 |
|
11 | | -MONGODB_CONF_VARS = ( |
12 | | - "MONGODB_ALIAS", |
13 | | - "MONGODB_DB", |
14 | | - "MONGODB_HOST", |
15 | | - "MONGODB_IS_MOCK", |
16 | | - "MONGODB_PASSWORD", |
17 | | - "MONGODB_PORT", |
18 | | - "MONGODB_USERNAME", |
19 | | - "MONGODB_CONNECT", |
20 | | - "MONGODB_TZ_AWARE", |
21 | | -) |
22 | | - |
| 11 | +def _get_name(setting_name: str) -> str: |
| 12 | + """ |
| 13 | + Return known pymongo setting name, or lower case name for unknown. |
23 | 14 |
|
24 | | -class InvalidSettingsError(Exception): |
25 | | - pass |
| 15 | + This problem discovered in issue #451. As mentioned there pymongo settings are not |
| 16 | + case-sensitive, but mongoengine use exact name of some settings for matching, |
| 17 | + overwriting pymongo behaviour. |
26 | 18 |
|
| 19 | + This function address this issue, and potentially address cases when pymongo will |
| 20 | + become case-sensitive in some settings by same reasons as mongoengine done. |
27 | 21 |
|
28 | | -def _sanitize_settings(settings): |
29 | | - """Given a dict of connection settings, sanitize the keys and fall |
30 | | - back to some sane defaults. |
| 22 | + Based on pymongo 4.1.1 settings. |
31 | 23 | """ |
32 | | - # Remove the "MONGODB_" prefix and make all settings keys lower case. |
| 24 | + KNOWN_CAMEL_CASE_SETTINGS = { |
| 25 | + "directconnection": "directConnection", |
| 26 | + "maxpoolsize": "maxPoolSize", |
| 27 | + "minpoolsize": "minPoolSize", |
| 28 | + "maxidletimems": "maxIdleTimeMS", |
| 29 | + "maxconnecting": "maxConnecting", |
| 30 | + "sockettimeoutms": "socketTimeoutMS", |
| 31 | + "connecttimeoutms": "connectTimeoutMS", |
| 32 | + "serverselectiontimeoutms": "serverSelectionTimeoutMS", |
| 33 | + "waitqueuetimeoutms": "waitQueueTimeoutMS", |
| 34 | + "heartbeatfrequencyms": "heartbeatFrequencyMS", |
| 35 | + "retrywrites": "retryWrites", |
| 36 | + "retryreads": "retryReads", |
| 37 | + "zlibcompressionlevel": "zlibCompressionLevel", |
| 38 | + "uuidrepresentation": "uuidRepresentation", |
| 39 | + "srvservicename": "srvServiceName", |
| 40 | + "wtimeoutms": "wTimeoutMS", |
| 41 | + "replicaset": "replicaSet", |
| 42 | + "readpreference": "readPreference", |
| 43 | + "readpreferencetags": "readPreferenceTags", |
| 44 | + "maxstalenessseconds": "maxStalenessSeconds", |
| 45 | + "authsource": "authSource", |
| 46 | + "authmechanism": "authMechanism", |
| 47 | + "authmechanismproperties": "authMechanismProperties", |
| 48 | + "tlsinsecure": "tlsInsecure", |
| 49 | + "tlsallowinvalidcertificates": "tlsAllowInvalidCertificates", |
| 50 | + "tlsallowinvalidhostnames": "tlsAllowInvalidHostnames", |
| 51 | + "tlscafile": "tlsCAFile", |
| 52 | + "tlscertificatekeyfile": "tlsCertificateKeyFile", |
| 53 | + "tlscrlfile": "tlsCRLFile", |
| 54 | + "tlscertificatekeyfilepassword": "tlsCertificateKeyFilePassword", |
| 55 | + "tlsdisableocspendpointcheck": "tlsDisableOCSPEndpointCheck", |
| 56 | + "readconcernlevel": "readConcernLevel", |
| 57 | + } |
| 58 | + _setting_name = KNOWN_CAMEL_CASE_SETTINGS.get(setting_name.lower()) |
| 59 | + return setting_name.lower() if _setting_name is None else _setting_name |
| 60 | + |
| 61 | + |
| 62 | +def _sanitize_settings(settings: dict) -> dict: |
| 63 | + """Remove MONGODB_ prefix from dict values, to correct bypass to mongoengine.""" |
33 | 64 | resolved_settings = {} |
34 | 65 | for k, v in settings.items(): |
35 | | - if k.startswith("MONGODB_"): |
36 | | - k = k[len("MONGODB_") :] |
37 | | - k = k.lower() |
38 | | - resolved_settings[k] = v |
39 | | - |
40 | | - # Handle uri style connections |
41 | | - if "://" in resolved_settings.get("host", ""): |
42 | | - # this section pulls the database name from the URI |
43 | | - # PyMongo requires URI to start with mongodb:// to parse |
44 | | - # this workaround allows mongomock to work |
45 | | - uri_to_check = resolved_settings["host"] |
46 | | - |
47 | | - if uri_to_check.startswith("mongomock://"): |
48 | | - uri_to_check = uri_to_check.replace("mongomock://", "mongodb://") |
49 | | - |
50 | | - uri_dict = uri_parser.parse_uri(uri_to_check) |
51 | | - resolved_settings["db"] = uri_dict["database"] |
52 | | - |
53 | | - # Add a default name param or use the "db" key if exists |
54 | | - if resolved_settings.get("db"): |
55 | | - resolved_settings["name"] = resolved_settings.pop("db") |
56 | | - else: |
57 | | - resolved_settings["name"] = "test" |
58 | | - |
59 | | - # Add various default values. |
60 | | - resolved_settings["alias"] = resolved_settings.get( |
61 | | - "alias", mongoengine.DEFAULT_CONNECTION_NAME |
62 | | - ) |
63 | | - # TODO do we have to specify it here? MongoEngine should take care of that |
64 | | - resolved_settings["host"] = resolved_settings.get("host", "localhost") |
65 | | - # TODO this is the default host in pymongo.mongo_client.MongoClient, we may |
66 | | - # not need to explicitly set a default here |
67 | | - resolved_settings["port"] = resolved_settings.get("port", 27017) |
68 | | - # TODO this is the default port in pymongo.mongo_client.MongoClient, we may |
69 | | - # not need to explicitly set a default here |
70 | | - |
71 | | - # Default to ReadPreference.PRIMARY if no read_preference is supplied |
72 | | - resolved_settings["read_preference"] = resolved_settings.get( |
73 | | - "read_preference", ReadPreference.PRIMARY |
74 | | - ) |
75 | | - |
76 | | - # Clean up empty values |
77 | | - for k, v in list(resolved_settings.items()): |
78 | | - if v is None: |
79 | | - del resolved_settings[k] |
| 66 | + # Replace with k.lower().removeprefix("mongodb_") when python 3.8 support ends. |
| 67 | + key = _get_name(k[8:]) if k.lower().startswith("mongodb_") else _get_name(k) |
| 68 | + resolved_settings[key] = v |
80 | 69 |
|
81 | 70 | return resolved_settings |
82 | 71 |
|
83 | 72 |
|
84 | | -def get_connection_settings(config): |
| 73 | +def get_connection_settings(config: dict) -> List[dict]: |
85 | 74 | """ |
86 | 75 | Given a config dict, return a sanitized dict of MongoDB connection |
87 | 76 | settings that we can then use to establish connections. For new |
88 | 77 | applications, settings should exist in a ``MONGODB_SETTINGS`` key, but |
89 | 78 | for backward compatibility we also support several config keys |
90 | 79 | prefixed by ``MONGODB_``, e.g. ``MONGODB_HOST``, ``MONGODB_PORT``, etc. |
91 | 80 | """ |
| 81 | + |
| 82 | + # If no "MONGODB_SETTINGS", sanitize the "MONGODB_" keys as single connection. |
| 83 | + if "MONGODB_SETTINGS" not in config: |
| 84 | + config = {k: v for k, v in config.items() if k.lower().startswith("mongodb_")} |
| 85 | + return [_sanitize_settings(config)] |
| 86 | + |
92 | 87 | # Sanitize all the settings living under a "MONGODB_SETTINGS" config var |
93 | | - if "MONGODB_SETTINGS" in config: |
94 | | - settings = config["MONGODB_SETTINGS"] |
| 88 | + settings = config["MONGODB_SETTINGS"] |
95 | 89 |
|
96 | | - # If MONGODB_SETTINGS is a list of settings dicts, sanitize each |
97 | | - # dict separately. |
98 | | - if isinstance(settings, list): |
99 | | - return [_sanitize_settings(setting) for setting in settings] |
100 | | - else: |
101 | | - return _sanitize_settings(settings) |
| 90 | + # If MONGODB_SETTINGS is a list of settings dicts, sanitize each dict separately. |
| 91 | + if isinstance(settings, list): |
| 92 | + return [_sanitize_settings(settings_dict) for settings_dict in settings] |
102 | 93 |
|
103 | | - else: |
104 | | - config = {k: v for k, v in config.items() if k in MONGODB_CONF_VARS} |
105 | | - return _sanitize_settings(config) |
| 94 | + # Otherwise, it should be a single dict describing a single connection. |
| 95 | + return [_sanitize_settings(settings)] |
106 | 96 |
|
107 | 97 |
|
108 | | -def create_connections(config): |
| 98 | +def create_connections(config: dict): |
109 | 99 | """ |
110 | 100 | Given Flask application's config dict, extract relevant config vars |
111 | 101 | out of it and establish MongoEngine connection(s) based on them. |
112 | 102 | """ |
113 | | - # Validate that the config is a dict |
114 | | - if config is None or not isinstance(config, dict): |
115 | | - raise InvalidSettingsError("Invalid application configuration") |
| 103 | + # Validate that the config is a dict and dict is not empty |
| 104 | + if not config or not isinstance(config, dict): |
| 105 | + raise TypeError(f"Config dictionary expected, but {type(config)} received.") |
116 | 106 |
|
117 | 107 | # Get sanitized connection settings based on the config |
118 | | - conn_settings = get_connection_settings(config) |
119 | | - |
120 | | - # If conn_settings is a list, set up each item as a separate connection |
121 | | - # and return a dict of connection aliases and their connections. |
122 | | - if isinstance(conn_settings, list): |
123 | | - connections = {} |
124 | | - for each in conn_settings: |
125 | | - alias = each["alias"] |
126 | | - connections[alias] = _connect(each) |
127 | | - return connections |
| 108 | + connection_settings = get_connection_settings(config) |
128 | 109 |
|
129 | | - # Otherwise, return a single connection |
130 | | - return _connect(conn_settings) |
| 110 | + connections = {} |
| 111 | + for connection_setting in connection_settings: |
| 112 | + alias = connection_setting.setdefault( |
| 113 | + "alias", mongoengine.DEFAULT_CONNECTION_NAME |
| 114 | + ) |
| 115 | + connections[alias] = mongoengine.connect(**connection_setting) |
131 | 116 |
|
132 | | - |
133 | | -def _connect(conn_settings): |
134 | | - """Given a dict of connection settings, create a connection to |
135 | | - MongoDB by calling {func}`mongoengine.connect` and return its result. |
136 | | - """ |
137 | | - db_name = conn_settings.pop("name") |
138 | | - return mongoengine.connect(db_name, **conn_settings) |
| 117 | + return connections |
0 commit comments