1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- """Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
15+ """Shared helper methods for pymongo, bson, and gridfs test suites."""
1616from __future__ import annotations
1717
1818import asyncio
19- import base64
20- import gc
21- import multiprocessing
22- import os
23- import signal
24- import socket
25- import subprocess
26- import sys
2719import threading
28- import time
29- import traceback
30- import unittest
31- import warnings
32- from inspect import iscoroutinefunction
33- from pathlib import Path
20+ from typing import Optional
3421
22+ from bson import SON
3523from pymongo ._asyncio_task import create_task
36-
37- try :
38- import ipaddress
39-
40- HAVE_IPADDRESS = True
41- except ImportError :
42- HAVE_IPADDRESS = False
43- from functools import wraps
44- from typing import Any , Callable , Dict , Generator , Optional , no_type_check
45- from unittest import SkipTest
46-
47- from bson .son import SON
48- from pymongo import common , message
4924from pymongo .read_preferences import ReadPreference
50- from pymongo .ssl_support import HAVE_SSL , _ssl # type:ignore[attr-defined]
51- from pymongo .synchronous .uri_parser import parse_uri
52-
53- if HAVE_SSL :
54- import ssl
5525
5626_IS_SYNC = False
5727
58- # Enable debug output for uncollectable objects. PyPy does not have set_debug.
59- if hasattr (gc , "set_debug" ):
60- gc .set_debug (
61- gc .DEBUG_UNCOLLECTABLE | getattr (gc , "DEBUG_OBJECTS" , 0 ) | getattr (gc , "DEBUG_INSTANCES" , 0 )
62- )
63-
64- # The host and port of a single mongod or mongos, or the seed host
65- # for a replica set.
66- host = os .environ .get ("DB_IP" , "localhost" )
67- port = int (os .environ .get ("DB_PORT" , 27017 ))
68- IS_SRV = "mongodb+srv" in host
69-
70- db_user = os .environ .get ("DB_USER" , "user" )
71- db_pwd = os .environ .get ("DB_PASSWORD" , "password" )
72-
73- HERE = Path (__file__ ).absolute ()
74- if _IS_SYNC :
75- CERT_PATH = str (HERE .parent / "certificates" )
76- else :
77- CERT_PATH = str (HERE .parent .parent / "certificates" )
78- CLIENT_PEM = os .environ .get ("CLIENT_PEM" , os .path .join (CERT_PATH , "client.pem" ))
79- CA_PEM = os .environ .get ("CA_PEM" , os .path .join (CERT_PATH , "ca.pem" ))
80-
81- TLS_OPTIONS : Dict = {"tls" : True }
82- if CLIENT_PEM :
83- TLS_OPTIONS ["tlsCertificateKeyFile" ] = CLIENT_PEM
84- if CA_PEM :
85- TLS_OPTIONS ["tlsCAFile" ] = CA_PEM
86-
87- COMPRESSORS = os .environ .get ("COMPRESSORS" )
88- MONGODB_API_VERSION = os .environ .get ("MONGODB_API_VERSION" )
89- TEST_LOADBALANCER = bool (os .environ .get ("TEST_LOAD_BALANCER" ))
90- SINGLE_MONGOS_LB_URI = os .environ .get ("SINGLE_MONGOS_LB_URI" )
91- MULTI_MONGOS_LB_URI = os .environ .get ("MULTI_MONGOS_LB_URI" )
92-
93- if TEST_LOADBALANCER :
94- res = parse_uri (SINGLE_MONGOS_LB_URI or "" )
95- host , port = res ["nodelist" ][0 ]
96- db_user = res ["username" ] or db_user
97- db_pwd = res ["password" ] or db_pwd
98-
99-
100- # Shared KMS data.
101- LOCAL_MASTER_KEY = base64 .b64decode (
102- b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ"
103- b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
104- )
105- AWS_CREDS = {
106- "accessKeyId" : os .environ .get ("FLE_AWS_KEY" , "" ),
107- "secretAccessKey" : os .environ .get ("FLE_AWS_SECRET" , "" ),
108- }
109- AWS_CREDS_2 = {
110- "accessKeyId" : os .environ .get ("FLE_AWS_KEY2" , "" ),
111- "secretAccessKey" : os .environ .get ("FLE_AWS_SECRET2" , "" ),
112- }
113- AZURE_CREDS = {
114- "tenantId" : os .environ .get ("FLE_AZURE_TENANTID" , "" ),
115- "clientId" : os .environ .get ("FLE_AZURE_CLIENTID" , "" ),
116- "clientSecret" : os .environ .get ("FLE_AZURE_CLIENTSECRET" , "" ),
117- }
118- GCP_CREDS = {
119- "email" : os .environ .get ("FLE_GCP_EMAIL" , "" ),
120- "privateKey" : os .environ .get ("FLE_GCP_PRIVATEKEY" , "" ),
121- }
122- KMIP_CREDS = {"endpoint" : os .environ .get ("FLE_KMIP_ENDPOINT" , "localhost:5698" )}
123- AWS_TEMP_CREDS = {
124- "accessKeyId" : os .environ .get ("CSFLE_AWS_TEMP_ACCESS_KEY_ID" , "" ),
125- "secretAccessKey" : os .environ .get ("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY" , "" ),
126- "sessionToken" : os .environ .get ("CSFLE_AWS_TEMP_SESSION_TOKEN" , "" ),
127- }
128-
129- ALL_KMS_PROVIDERS = dict (
130- aws = AWS_CREDS ,
131- azure = AZURE_CREDS ,
132- gcp = GCP_CREDS ,
133- local = dict (key = LOCAL_MASTER_KEY ),
134- kmip = KMIP_CREDS ,
135- )
136- DEFAULT_KMS_TLS = dict (kmip = dict (tlsCAFile = CA_PEM , tlsCertificateKeyFile = CLIENT_PEM ))
137-
138- # Ensure Evergreen metadata doesn't result in truncation
139- os .environ .setdefault ("MONGOB_LOG_MAX_DOCUMENT_LENGTH" , "2000" )
140-
141-
142- def is_server_resolvable ():
143- """Returns True if 'server' is resolvable."""
144- socket_timeout = socket .getdefaulttimeout ()
145- socket .setdefaulttimeout (1 )
146- try :
147- try :
148- socket .gethostbyname ("server" )
149- return True
150- except OSError :
151- return False
152- finally :
153- socket .setdefaulttimeout (socket_timeout )
154-
155-
156- def _create_user (authdb , user , pwd = None , roles = None , ** kwargs ):
157- cmd = SON ([("createUser" , user )])
158- # X509 doesn't use a password
159- if pwd :
160- cmd ["pwd" ] = pwd
161- cmd ["roles" ] = roles or ["root" ]
162- cmd .update (** kwargs )
163- return authdb .command (cmd )
164-
16528
16629async def async_repl_set_step_down (client , ** kwargs ):
16730 """Run replSetStepDown, first unfreezing a secondary with replSetFreeze."""
@@ -173,216 +36,6 @@ async def async_repl_set_step_down(client, **kwargs):
17336 await client .admin .command (cmd )
17437
17538
176- class client_knobs :
177- def __init__ (
178- self ,
179- heartbeat_frequency = None ,
180- min_heartbeat_interval = None ,
181- kill_cursor_frequency = None ,
182- events_queue_frequency = None ,
183- ):
184- self .heartbeat_frequency = heartbeat_frequency
185- self .min_heartbeat_interval = min_heartbeat_interval
186- self .kill_cursor_frequency = kill_cursor_frequency
187- self .events_queue_frequency = events_queue_frequency
188-
189- self .old_heartbeat_frequency = None
190- self .old_min_heartbeat_interval = None
191- self .old_kill_cursor_frequency = None
192- self .old_events_queue_frequency = None
193- self ._enabled = False
194- self ._stack = None
195-
196- def enable (self ):
197- self .old_heartbeat_frequency = common .HEARTBEAT_FREQUENCY
198- self .old_min_heartbeat_interval = common .MIN_HEARTBEAT_INTERVAL
199- self .old_kill_cursor_frequency = common .KILL_CURSOR_FREQUENCY
200- self .old_events_queue_frequency = common .EVENTS_QUEUE_FREQUENCY
201-
202- if self .heartbeat_frequency is not None :
203- common .HEARTBEAT_FREQUENCY = self .heartbeat_frequency
204-
205- if self .min_heartbeat_interval is not None :
206- common .MIN_HEARTBEAT_INTERVAL = self .min_heartbeat_interval
207-
208- if self .kill_cursor_frequency is not None :
209- common .KILL_CURSOR_FREQUENCY = self .kill_cursor_frequency
210-
211- if self .events_queue_frequency is not None :
212- common .EVENTS_QUEUE_FREQUENCY = self .events_queue_frequency
213- self ._enabled = True
214- # Store the allocation traceback to catch non-disabled client_knobs.
215- self ._stack = "" .join (traceback .format_stack ())
216-
217- def __enter__ (self ):
218- self .enable ()
219-
220- @no_type_check
221- def disable (self ):
222- common .HEARTBEAT_FREQUENCY = self .old_heartbeat_frequency
223- common .MIN_HEARTBEAT_INTERVAL = self .old_min_heartbeat_interval
224- common .KILL_CURSOR_FREQUENCY = self .old_kill_cursor_frequency
225- common .EVENTS_QUEUE_FREQUENCY = self .old_events_queue_frequency
226- self ._enabled = False
227-
228- def __exit__ (self , exc_type , exc_val , exc_tb ):
229- self .disable ()
230-
231- def __call__ (self , func ):
232- def make_wrapper (f ):
233- @wraps (f )
234- async def wrap (* args , ** kwargs ):
235- with self :
236- return await f (* args , ** kwargs )
237-
238- return wrap
239-
240- return make_wrapper (func )
241-
242- def __del__ (self ):
243- if self ._enabled :
244- msg = (
245- "ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
246- "MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
247- "EVENTS_QUEUE_FREQUENCY={}, stack:\n {}" .format (
248- common .HEARTBEAT_FREQUENCY ,
249- common .MIN_HEARTBEAT_INTERVAL ,
250- common .KILL_CURSOR_FREQUENCY ,
251- common .EVENTS_QUEUE_FREQUENCY ,
252- self ._stack ,
253- )
254- )
255- self .disable ()
256- raise Exception (msg )
257-
258-
259- def _all_users (db ):
260- return {u ["user" ] for u in db .command ("usersInfo" ).get ("users" , [])}
261-
262-
263- def sanitize_cmd (cmd ):
264- cp = cmd .copy ()
265- cp .pop ("$clusterTime" , None )
266- cp .pop ("$db" , None )
267- cp .pop ("$readPreference" , None )
268- cp .pop ("lsid" , None )
269- if MONGODB_API_VERSION :
270- # Stable API parameters
271- cp .pop ("apiVersion" , None )
272- # OP_MSG encoding may move the payload type one field to the
273- # end of the command. Do the same here.
274- name = next (iter (cp ))
275- try :
276- identifier = message ._FIELD_MAP [name ]
277- docs = cp .pop (identifier )
278- cp [identifier ] = docs
279- except KeyError :
280- pass
281- return cp
282-
283-
284- def sanitize_reply (reply ):
285- cp = reply .copy ()
286- cp .pop ("$clusterTime" , None )
287- cp .pop ("operationTime" , None )
288- return cp
289-
290-
291- def print_thread_tracebacks () -> None :
292- """Print all Python thread tracebacks."""
293- for thread_id , frame in sys ._current_frames ().items ():
294- sys .stderr .write (f"\n --- Traceback for thread { thread_id } ---\n " )
295- traceback .print_stack (frame , file = sys .stderr )
296-
297-
298- def print_thread_stacks (pid : int ) -> None :
299- """Print all C-level thread stacks for a given process id."""
300- if sys .platform == "darwin" :
301- cmd = ["lldb" , "--attach-pid" , f"{ pid } " , "--batch" , "--one-line" , '"thread backtrace all"' ]
302- else :
303- cmd = ["gdb" , f"--pid={ pid } " , "--batch" , '--eval-command="thread apply all bt"' ]
304-
305- try :
306- res = subprocess .run (
307- cmd , stdout = subprocess .PIPE , stderr = subprocess .STDOUT , encoding = "utf-8"
308- )
309- except Exception as exc :
310- sys .stderr .write (f"Could not print C-level thread stacks because { cmd [0 ]} failed: { exc } " )
311- else :
312- sys .stderr .write (res .stdout )
313-
314-
315- # Global knobs to speed up the test suite.
316- global_knobs = client_knobs (events_queue_frequency = 0.05 )
317-
318-
319- def _get_executors (topology ):
320- executors = []
321- for server in topology ._servers .values ():
322- # Some MockMonitor do not have an _executor.
323- if hasattr (server ._monitor , "_executor" ):
324- executors .append (server ._monitor ._executor )
325- if hasattr (server ._monitor , "_rtt_monitor" ):
326- executors .append (server ._monitor ._rtt_monitor ._executor )
327- executors .append (topology ._Topology__events_executor )
328- if topology ._srv_monitor :
329- executors .append (topology ._srv_monitor ._executor )
330-
331- return [e for e in executors if e is not None ]
332-
333-
334- def print_running_topology (topology ):
335- running = [e for e in _get_executors (topology ) if not e ._stopped ]
336- if running :
337- print (
338- "WARNING: found Topology with running threads:\n "
339- f" Threads: { running } \n "
340- f" Topology: { topology } \n "
341- f" Creation traceback:\n { topology ._settings ._stack } "
342- )
343-
344-
345- def test_cases (suite ):
346- """Iterator over all TestCases within a TestSuite."""
347- for suite_or_case in suite ._tests :
348- if isinstance (suite_or_case , unittest .TestCase ):
349- # unittest.TestCase
350- yield suite_or_case
351- else :
352- # unittest.TestSuite
353- yield from test_cases (suite_or_case )
354-
355-
356- # Helper method to workaround https://bugs.python.org/issue21724
357- def clear_warning_registry ():
358- """Clear the __warningregistry__ for all modules."""
359- for _ , module in list (sys .modules .items ()):
360- if hasattr (module , "__warningregistry__" ):
361- module .__warningregistry__ = {} # type:ignore[attr-defined]
362-
363-
364- class SystemCertsPatcher :
365- def __init__ (self , ca_certs ):
366- if (
367- ssl .OPENSSL_VERSION .lower ().startswith ("libressl" )
368- and sys .platform == "darwin"
369- and not _ssl .IS_PYOPENSSL
370- ):
371- raise SkipTest (
372- "LibreSSL on OSX doesn't support setting CA certificates "
373- "using SSL_CERT_FILE environment variable."
374- )
375- self .original_certs = os .environ .get ("SSL_CERT_FILE" )
376- # Tell OpenSSL where CA certificates live.
377- os .environ ["SSL_CERT_FILE" ] = ca_certs
378-
379- def disable (self ):
380- if self .original_certs is None :
381- os .environ .pop ("SSL_CERT_FILE" )
382- else :
383- os .environ ["SSL_CERT_FILE" ] = self .original_certs
384-
385-
38639if _IS_SYNC :
38740 PARENT = threading .Thread
38841else :
0 commit comments