|
20 | 20 |
|
21 | 21 | sys.path[0:0] = [""] |
22 | 22 |
|
| 23 | +from pymongo_auth_aws import AwsCredential, auth |
| 24 | + |
23 | 25 | from pymongo import MongoClient |
24 | 26 | from pymongo.errors import OperationFailure |
25 | 27 | from pymongo.uri_parser import parse_uri |
@@ -53,6 +55,62 @@ def test_connect_uri(self): |
53 | 55 | with MongoClient(self.uri) as client: |
54 | 56 | client.get_database().test.find_one() |
55 | 57 |
|
| 58 | + def setup_cache(self): |
| 59 | + if os.environ.get("AWS_ACCESS_KEY_ID", None) or "@" in self.uri: |
| 60 | + self.skipTest("Not testing cached credentials") |
| 61 | + if not hasattr(auth, "set_cached_credentials"): |
| 62 | + self.skipTest("Cached credentials not available") |
| 63 | + |
| 64 | + # Ensure cleared credentials. |
| 65 | + auth.set_cached_credentials(None) |
| 66 | + self.assertEqual(auth.get_cached_credentials(), None) |
| 67 | + |
| 68 | + client = MongoClient(self.uri) |
| 69 | + client.get_database().test.find_one() |
| 70 | + client.close() |
| 71 | + return auth.get_cached_credentials() |
| 72 | + |
| 73 | + def test_cache_credentials(self): |
| 74 | + creds = self.setup_cache() |
| 75 | + self.assertIsNotNone(creds) |
| 76 | + |
| 77 | + def test_cache_about_to_expire(self): |
| 78 | + creds = self.setup_cache() |
| 79 | + client = MongoClient(self.uri) |
| 80 | + self.addCleanup(client.close) |
| 81 | + |
| 82 | + # Make the creds about to expire. |
| 83 | + creds = auth.get_cached_credentials() |
| 84 | + assert creds is not None |
| 85 | + |
| 86 | + creds = AwsCredential(creds.username, creds.password, creds.token, lambda x: True) |
| 87 | + auth.set_cached_credentials(creds) |
| 88 | + |
| 89 | + client.get_database().test.find_one() |
| 90 | + new_creds = auth.get_cached_credentials() |
| 91 | + self.assertNotEqual(creds, new_creds) |
| 92 | + |
| 93 | + def test_poisoned_cache(self): |
| 94 | + creds = self.setup_cache() |
| 95 | + |
| 96 | + client = MongoClient(self.uri) |
| 97 | + self.addCleanup(client.close) |
| 98 | + |
| 99 | + # Poison the creds with invalid password. |
| 100 | + assert creds is not None |
| 101 | + creds = AwsCredential("a" * 24, "b" * 24, "c" * 24) |
| 102 | + auth.set_cached_credentials(creds) |
| 103 | + |
| 104 | + with self.assertRaises(OperationFailure): |
| 105 | + client.get_database().test.find_one() |
| 106 | + |
| 107 | + # Make sure the cache was cleared. |
| 108 | + self.assertEqual(auth.get_cached_credentials(), None) |
| 109 | + |
| 110 | + # The next attempt should generate a new cred and succeed. |
| 111 | + client.get_database().test.find_one() |
| 112 | + self.assertNotEqual(auth.get_cached_credentials(), None) |
| 113 | + |
56 | 114 |
|
57 | 115 | class TestAWSLambdaExamples(unittest.TestCase): |
58 | 116 | def test_shared_client(self): |
|
0 commit comments