1818import sys
1919import unittest
2020
21+ from collections import defaultdict
22+
2123sys .path [0 :0 ] = ["" ]
2224
2325import pymongo
2426from pymongo .ssl_support import HAS_SNI
2527
26-
27- _REPL = os .environ .get ("ATLAS_REPL" )
28- _SHRD = os .environ .get ("ATLAS_SHRD" )
29- _FREE = os .environ .get ("ATLAS_FREE" )
30- _TLS11 = os .environ .get ("ATLAS_TLS11" )
31- _TLS12 = os .environ .get ("ATLAS_TLS12" )
32-
33-
34- def _connect (uri ):
28+ try :
29+ import dns
30+ HAS_DNS = True
31+ except ImportError :
32+ HAS_DNS = False
33+
34+
35+ URIS = {
36+ "ATLAS_REPL" : os .environ .get ("ATLAS_REPL" ),
37+ "ATLAS_SHRD" : os .environ .get ("ATLAS_SHRD" ),
38+ "ATLAS_FREE" : os .environ .get ("ATLAS_FREE" ),
39+ "ATLAS_TLS11" : os .environ .get ("ATLAS_TLS11" ),
40+ "ATLAS_TLS12" : os .environ .get ("ATLAS_TLS12" ),
41+ "ATLAS_SRV_REPL" : os .environ .get ("ATLAS_SRV_REPL" ),
42+ "ATLAS_SRV_SHRD" : os .environ .get ("ATLAS_SRV_SHRD" ),
43+ "ATLAS_SRV_FREE" : os .environ .get ("ATLAS_SRV_FREE" ),
44+ "ATLAS_SRV_TLS11" : os .environ .get ("ATLAS_SRV_TLS11" ),
45+ "ATLAS_SRV_TLS12" : os .environ .get ("ATLAS_SRV_TLS12" ),
46+ }
47+
48+ # Set this variable to true to run the SRV tests even when dnspython is not
49+ # installed.
50+ MUST_TEST_SRV = os .environ .get ("MUST_TEST_SRV" )
51+
52+
53+ def connect (uri ):
54+ if not uri :
55+ raise Exception ("Must set env variable to test." )
3556 client = pymongo .MongoClient (uri )
3657 # No TLS error
3758 client .admin .command ('ismaster' )
@@ -40,29 +61,58 @@ def _connect(uri):
4061
4162
4263class TestAtlasConnect (unittest .TestCase ):
43-
44- @classmethod
45- def setUpClass (cls ):
46- if not all ([_REPL , _SHRD , _FREE ]):
47- raise Exception (
48- "Must set ATLAS_REPL/SHRD/FREE env variables to test." )
64+ @unittest .skipUnless (HAS_SNI , 'Free tier requires SNI support' )
65+ def test_free_tier (self ):
66+ connect (URIS ['ATLAS_FREE' ])
4967
5068 def test_replica_set (self ):
51- _connect ( _REPL )
69+ connect ( URIS [ 'ATLAS_REPL' ] )
5270
5371 def test_sharded_cluster (self ):
54- _connect (_SHRD )
55-
56- def test_free_tier (self ):
57- if not HAS_SNI :
58- raise unittest .SkipTest ("Free tier requires SNI support." )
59- _connect (_FREE )
72+ connect (URIS ['ATLAS_SHRD' ])
6073
6174 def test_tls_11 (self ):
62- _connect ( _TLS11 )
75+ connect ( URIS [ 'ATLAS_TLS11' ] )
6376
6477 def test_tls_12 (self ):
65- _connect (_TLS12 )
78+ connect (URIS ['ATLAS_TLS12' ])
79+
80+ def connect_srv (self , uri ):
81+ connect (uri )
82+ self .assertIn ('mongodb+srv://' , uri )
83+
84+ @unittest .skipUnless (HAS_SNI , 'Free tier requires SNI support' )
85+ @unittest .skipUnless (HAS_DNS or MUST_TEST_SRV , 'SRV requires dnspython' )
86+ def test_srv_free_tier (self ):
87+ self .connect_srv (URIS ['ATLAS_SRV_FREE' ])
88+
89+ @unittest .skipUnless (HAS_DNS or MUST_TEST_SRV , 'SRV requires dnspython' )
90+ def test_srv_replica_set (self ):
91+ self .connect_srv (URIS ['ATLAS_SRV_REPL' ])
92+
93+ @unittest .skipUnless (HAS_DNS or MUST_TEST_SRV , 'SRV requires dnspython' )
94+ def test_srv_sharded_cluster (self ):
95+ self .connect_srv (URIS ['ATLAS_SRV_SHRD' ])
96+
97+ @unittest .skipUnless (HAS_DNS or MUST_TEST_SRV , 'SRV requires dnspython' )
98+ def test_srv_tls_11 (self ):
99+ self .connect_srv (URIS ['ATLAS_SRV_TLS11' ])
100+
101+ @unittest .skipUnless (HAS_DNS or MUST_TEST_SRV , 'SRV requires dnspython' )
102+ def test_srv_tls_12 (self ):
103+ self .connect_srv (URIS ['ATLAS_SRV_TLS12' ])
104+
105+ def test_uniqueness (self ):
106+ """Ensure that we don't accidentally duplicate the test URIs."""
107+ uri_to_names = defaultdict (list )
108+ for name , uri in URIS .items ():
109+ if uri :
110+ uri_to_names [uri ].append (name )
111+ duplicates = [names for names in uri_to_names .values ()
112+ if len (names ) > 1 ]
113+ self .assertFalse (duplicates , 'Error: the following env variables have '
114+ 'duplicate values: %s' % (duplicates ,))
115+
66116
67117
68118if __name__ == '__main__' :
0 commit comments