|
17 | 17 | import os |
18 | 18 | import sys |
19 | 19 | import unittest |
| 20 | +from unittest.mock import patch |
20 | 21 |
|
21 | 22 | sys.path[0:0] = [""] |
22 | 23 |
|
@@ -111,6 +112,63 @@ def test_poisoned_cache(self): |
111 | 112 | client.get_database().test.find_one() |
112 | 113 | self.assertNotEqual(auth.get_cached_credentials(), None) |
113 | 114 |
|
| 115 | + def test_environment_variables_ignored(self): |
| 116 | + creds = self.setup_cache() |
| 117 | + self.assertIsNotNone(creds) |
| 118 | + prev = os.environ.copy() |
| 119 | + |
| 120 | + client = MongoClient(self.uri) |
| 121 | + self.addCleanup(client.close) |
| 122 | + |
| 123 | + client.get_database().test.find_one() |
| 124 | + |
| 125 | + self.assertIsNotNone(auth.get_cached_credentials()) |
| 126 | + |
| 127 | + mock_env = dict( |
| 128 | + AWS_ACCESS_KEY_ID="foo", AWS_SECRET_ACCESS_KEY="bar", AWS_SESSION_TOKEN="baz" |
| 129 | + ) |
| 130 | + |
| 131 | + with patch.dict("os.environ", mock_env): |
| 132 | + self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") |
| 133 | + client.get_database().test.find_one() |
| 134 | + |
| 135 | + auth.set_cached_credentials(None) |
| 136 | + |
| 137 | + client2 = MongoClient(self.uri) |
| 138 | + self.addCleanup(client2.close) |
| 139 | + |
| 140 | + with patch.dict("os.environ", mock_env): |
| 141 | + self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") |
| 142 | + with self.assertRaises(OperationFailure): |
| 143 | + client2.get_database().test.find_one() |
| 144 | + |
| 145 | + def test_no_cache_environment_variables(self): |
| 146 | + creds = self.setup_cache() |
| 147 | + self.assertIsNotNone(creds) |
| 148 | + auth.set_cached_credentials(None) |
| 149 | + |
| 150 | + mock_env = dict(AWS_ACCESS_KEY_ID=creds.username, AWS_SECRET_ACCESS_KEY=creds.password) |
| 151 | + if creds.token: |
| 152 | + mock_env["AWS_SESSION_TOKEN"] = creds.token |
| 153 | + |
| 154 | + client = MongoClient(self.uri) |
| 155 | + self.addCleanup(client.close) |
| 156 | + |
| 157 | + with patch.dict(os.environ, mock_env): |
| 158 | + self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], creds.username) |
| 159 | + client.get_database().test.find_one() |
| 160 | + |
| 161 | + self.assertIsNone(auth.get_cached_credentials()) |
| 162 | + |
| 163 | + mock_env["AWS_ACCESS_KEY_ID"] = "foo" |
| 164 | + |
| 165 | + client2 = MongoClient(self.uri) |
| 166 | + self.addCleanup(client2.close) |
| 167 | + |
| 168 | + with patch.dict("os.environ", mock_env), self.assertRaises(OperationFailure): |
| 169 | + self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") |
| 170 | + client2.get_database().test.find_one() |
| 171 | + |
114 | 172 |
|
115 | 173 | class TestAWSLambdaExamples(unittest.TestCase): |
116 | 174 | def test_shared_client(self): |
|
0 commit comments