Skip to content

Commit ec9cfee

Browse files
authored
Merge pull request #76 from aws-samples/ssl
SSL Rotation Support
2 parents 28b5b77 + 60fedc9 commit ec9cfee

File tree

10 files changed

+976
-142
lines changed

10 files changed

+976
-142
lines changed

SecretsManagerMongoDBRotationMultiUser/lambda_function.py

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def lambda_handler(event, context):
1717
This handler uses the master-user rotation scheme to rotate an MongoDB user credential. During the first rotation, this
1818
scheme logs into the database as the master user, creates a new user (appending _clone to the username), and grants the
1919
new user all of the permissions from the user being rotated. Once the secret is in this state, every subsequent rotation
20-
simply creates a new secret with the AWSPREVIOUS user credentials, adds any missing permissions that are in the current
21-
secret, changes that user's password, and then marks the latest secret as AWSCURRENT.
20+
simply creates a new secret with the AWSPREVIOUS user credentials, changes that user's password, and then marks the
21+
latest secret as AWSCURRENT.
2222
2323
The Secret SecretString is expected to be a JSON string with the following format:
2424
{
@@ -284,8 +284,9 @@ def finish_secret(service_client, arn, token):
284284
def get_connection(secret_dict):
285285
"""Gets a connection to MongoDB from a secret dictionary
286286
287-
This helper function tries to connect to the database grabbing connection info
288-
from the secret dictionary. If successful, it returns the connection, else None
287+
This helper function uses connectivity information from the secret dictionary to initiate
288+
connection attempt(s) to the database. Will attempt a fallback, non-SSL connection when
289+
initial connection fails using SSL and fall_back is True.
289290
290291
Args:
291292
secret_dict (dict): The Secret Dictionary
@@ -297,18 +298,86 @@ def get_connection(secret_dict):
297298
KeyError: If the secret json does not contain the expected keys
298299
299300
"""
301+
# Parse and validate the secret JSON string
300302
port = int(secret_dict['port']) if 'port' in secret_dict else 27017
301303
dbname = secret_dict['dbname'] if 'dbname' in secret_dict else "admin"
302-
ssl = False
303-
if 'ssl' in secret_dict:
304-
if type(secret_dict['ssl']) is bool:
305-
ssl = secret_dict['ssl']
304+
305+
# Get SSL connectivity configuration
306+
use_ssl, fall_back = get_ssl_config(secret_dict)
307+
308+
# if an 'ssl' key is not found or does not contain a valid value, attempt an SSL connection and fall back to non-SSL on failure
309+
conn = connect_and_authenticate(secret_dict, port, dbname, use_ssl)
310+
if conn or not fall_back:
311+
return conn
312+
else:
313+
return connect_and_authenticate(secret_dict, port, dbname, False)
314+
315+
316+
def get_ssl_config(secret_dict):
317+
"""Gets the desired SSL and fall back behavior using a secret dictionary
318+
319+
This helper function uses the existance and value the 'ssl' key in a secret dictionary
320+
to determine desired SSL connectivity configuration. Its behavior is as follows:
321+
- 'ssl' key DNE or invalid type/value: return True, True
322+
- 'ssl' key is bool: return secret_dict['ssl'], False
323+
- 'ssl' key equals "true" ignoring case: return True, False
324+
- 'ssl' key equals "false" ignoring case: return False, False
325+
326+
Args:
327+
secret_dict (dict): The Secret Dictionary
328+
329+
Returns:
330+
Tuple(use_ssl, fall_back): SSL configuration
331+
- use_ssl (bool): Flag indicating if an SSL connection should be attempted
332+
- fall_back (bool): Flag indicating if non-SSL connection should be attempted if SSL connection fails
333+
334+
"""
335+
# Default to True for SSL and fall_back mode if 'ssl' key DNE
336+
if 'ssl' not in secret_dict:
337+
return True, True
338+
339+
# Handle type bool
340+
if isinstance(secret_dict['ssl'], bool):
341+
return secret_dict['ssl'], False
342+
343+
# Handle type string
344+
if isinstance(secret_dict['ssl'], str):
345+
ssl = secret_dict['ssl'].lower()
346+
if ssl == "true":
347+
return True, False
348+
elif ssl == "false":
349+
return False, False
306350
else:
307-
ssl = (secret_dict['ssl'].lower() == "true")
351+
# Invalid string value, default to True for both SSL and fall_back mode
352+
return True, True
308353

354+
# Invalid type, default to True for both SSL and fall_back mode
355+
return True, True
356+
357+
358+
def connect_and_authenticate(secret_dict, port, dbname, use_ssl):
359+
"""Attempt to connect and authenticate to a MongoDB instance
360+
361+
This helper function tries to connect to the database using connectivity info passed in.
362+
If successful, it returns the connection, else None
363+
364+
Args:
365+
- secret_dict (dict): The Secret Dictionary
366+
- port (int): The databse port to connect to
367+
- dbname (str): Name of the database
368+
- use_ssl (bool): Flag indicating whether connection should use SSL/TLS
369+
370+
Returns:
371+
Connection: The pymongo.database.Database object if successful. None otherwise
372+
373+
Raises:
374+
KeyError: If the secret json does not contain the expected keys
375+
376+
"""
309377
# Try to obtain a connection to the db
310378
try:
311-
client = MongoClient(host=secret_dict['host'], port=port, connectTimeoutMS=5000, serverSelectionTimeoutMS=5000, ssl=ssl)
379+
# Hostname verfification and server certificate validation enabled by default when ssl=True
380+
client = MongoClient(host=secret_dict['host'], port=port, connectTimeoutMS=5000, serverSelectionTimeoutMS=5000, ssl=use_ssl)
312381
db = client[dbname]
313382
db.authenticate(secret_dict['username'], secret_dict['password'])
314383
return db

SecretsManagerMongoDBRotationSingleUser/lambda_function.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def set_secret(service_client, arn, token):
152152
previous_dict = None
153153
current_dict = get_secret_dict(service_client, arn, "AWSCURRENT")
154154
pending_dict = get_secret_dict(service_client, arn, "AWSPENDING", token)
155-
155+
156156
# First try to login with the pending secret, if it succeeds, return
157157
conn = get_connection(pending_dict)
158158
if conn:
@@ -175,12 +175,13 @@ def set_secret(service_client, arn, token):
175175

176176
# If both current and pending do not work, try previous
177177
if not conn and previous_dict:
178-
# If both current and pending do not work, try previous
179178
# Update previous_dict to leverage current SSL settings
180179
previous_dict.pop('ssl', None)
181180
if 'ssl' in current_dict:
182181
previous_dict['ssl'] = current_dict['ssl']
183182

183+
conn = get_connection(previous_dict)
184+
184185
# Make sure the user/host from previous and pending match
185186
if previous_dict['username'] != pending_dict['username']:
186187
logger.error("setSecret: Attempting to modify user %s other than previous valid user %s" % (pending_dict['username'], previous_dict['username']))
@@ -277,8 +278,9 @@ def finish_secret(service_client, arn, token):
277278
def get_connection(secret_dict):
278279
"""Gets a connection to MongoDB from a secret dictionary
279280
280-
This helper function tries to connect to the database grabbing connection info
281-
from the secret dictionary. If successful, it returns the connection, else None
281+
This helper function uses connectivity information from the secret dictionary to initiate
282+
connection attempt(s) to the database. Will attempt a fallback, non-SSL connection when
283+
initial connection fails using SSL and fall_back is True.
282284
283285
Args:
284286
secret_dict (dict): The Secret Dictionary
@@ -293,16 +295,83 @@ def get_connection(secret_dict):
293295
# Parse and validate the secret JSON string
294296
port = int(secret_dict['port']) if 'port' in secret_dict else 27017
295297
dbname = secret_dict['dbname'] if 'dbname' in secret_dict else "admin"
296-
ssl = False
297-
if 'ssl' in secret_dict:
298-
if type(secret_dict['ssl']) is bool:
299-
ssl = secret_dict['ssl']
298+
299+
# Get SSL connectivity configuration
300+
use_ssl, fall_back = get_ssl_config(secret_dict)
301+
302+
# if an 'ssl' key is not found or does not contain a valid value, attempt an SSL connection and fall back to non-SSL on failure
303+
conn = connect_and_authenticate(secret_dict, port, dbname, use_ssl)
304+
if conn or not fall_back:
305+
return conn
306+
else:
307+
return connect_and_authenticate(secret_dict, port, dbname, False)
308+
309+
310+
def get_ssl_config(secret_dict):
311+
"""Gets the desired SSL and fall back behavior using a secret dictionary
312+
313+
This helper function uses the existance and value the 'ssl' key in a secret dictionary
314+
to determine desired SSL connectivity configuration. Its behavior is as follows:
315+
- 'ssl' key DNE or invalid type/value: return True, True
316+
- 'ssl' key is bool: return secret_dict['ssl'], False
317+
- 'ssl' key equals "true" ignoring case: return True, False
318+
- 'ssl' key equals "false" ignoring case: return False, False
319+
320+
Args:
321+
secret_dict (dict): The Secret Dictionary
322+
323+
Returns:
324+
Tuple(use_ssl, fall_back): SSL configuration
325+
- use_ssl (bool): Flag indicating if an SSL connection should be attempted
326+
- fall_back (bool): Flag indicating if non-SSL connection should be attempted if SSL connection fails
327+
328+
"""
329+
# Default to True for SSL and fall_back mode if 'ssl' key DNE
330+
if 'ssl' not in secret_dict:
331+
return True, True
332+
333+
# Handle type bool
334+
if isinstance(secret_dict['ssl'], bool):
335+
return secret_dict['ssl'], False
336+
337+
# Handle type string
338+
if isinstance(secret_dict['ssl'], str):
339+
ssl = secret_dict['ssl'].lower()
340+
if ssl == "true":
341+
return True, False
342+
elif ssl == "false":
343+
return False, False
300344
else:
301-
ssl = (secret_dict['ssl'].lower() == "true")
302-
345+
# Invalid string value, default to True for both SSL and fall_back mode
346+
return True, True
347+
348+
# Invalid type, default to True for both SSL and fall_back mode
349+
return True, True
350+
351+
352+
def connect_and_authenticate(secret_dict, port, dbname, use_ssl):
353+
"""Attempt to connect and authenticate to a MongoDB instance
354+
355+
This helper function tries to connect to the database using connectivity info passed in.
356+
If successful, it returns the connection, else None
357+
358+
Args:
359+
- secret_dict (dict): The Secret Dictionary
360+
- port (int): The databse port to connect to
361+
- dbname (str): Name of the database
362+
- use_ssl (bool): Flag indicating whether connection should use SSL/TLS
363+
364+
Returns:
365+
Connection: The pymongo.database.Database object if successful. None otherwise
366+
367+
Raises:
368+
KeyError: If the secret json does not contain the expected keys
369+
370+
"""
303371
# Try to obtain a connection to the db
304372
try:
305-
client = MongoClient(host=secret_dict['host'], port=port, connectTimeoutMS=5000, serverSelectionTimeoutMS=5000, ssl=ssl)
373+
# Hostname verfification and server certificate validation enabled by default when ssl=True
374+
client = MongoClient(host=secret_dict['host'], port=port, connectTimeoutMS=5000, serverSelectionTimeoutMS=5000, ssl=use_ssl)
306375
db = client[dbname]
307376
db.authenticate(secret_dict['username'], secret_dict['password'])
308377
return db

SecretsManagerRDSMariaDBRotationMultiUser/lambda_function.py

Lines changed: 99 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def lambda_handler(event, context):
1717
This handler uses the master-user rotation scheme to rotate an RDS MariaDB user credential. During the first rotation, this
1818
scheme logs into the database as the master user, creates a new user (appending _clone to the username), and grants the
1919
new user all of the permissions from the user being rotated. Once the secret is in this state, every subsequent rotation
20-
simply creates a new secret with the AWSPREVIOUS user credentials, adds any missing permissions that are in the current
21-
secret, changes that user's password, and then marks the latest secret as AWSCURRENT.
20+
simply creates a new secret with the AWSPREVIOUS user credentials, changes that user's password, and then marks the
21+
latest secret as AWSCURRENT.
2222
2323
The Secret SecretString is expected to be a JSON string with the following format:
2424
{
@@ -154,7 +154,7 @@ def set_secret(service_client, arn, token):
154154
"""
155155
current_dict = get_secret_dict(service_client, arn, "AWSCURRENT")
156156
pending_dict = get_secret_dict(service_client, arn, "AWSPENDING", token)
157-
157+
158158
# First try to login with the pending secret, if it succeeds, return
159159
conn = get_connection(pending_dict)
160160
if conn:
@@ -202,8 +202,22 @@ def set_secret(service_client, arn, token):
202202
cur.execute("SHOW GRANTS FOR %s", current_dict['username'])
203203
for row in cur.fetchall():
204204
grant = row[0].split(' TO ')
205-
new_grant_escaped = grant[0].replace('%','%%') # % is a special character in Python format strings.
205+
new_grant_escaped = grant[0].replace('%', '%%') # % is a special character in Python format strings.
206206
cur.execute(new_grant_escaped + " TO %s IDENTIFIED BY %s", (pending_dict['username'], pending_dict['password']))
207+
208+
# Copy TLS options to the new user
209+
cur.execute("SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s", current_dict['username'])
210+
tls_options = cur.fetchone()
211+
ssl_type = tls_options[0]
212+
if not ssl_type:
213+
cur.execute("ALTER USER %s@'%%' REQUIRE NONE", pending_dict['username'])
214+
elif ssl_type == "ANY":
215+
cur.execute("ALTER USER %s@'%%' REQUIRE SSL", pending_dict['username'])
216+
elif ssl_type == "X509":
217+
cur.execute("ALTER USER %s@'%%' REQUIRE X509", pending_dict['username'])
218+
else:
219+
cur.execute("ALTER USER %s@'%%' REQUIRE CIPHER %s AND ISSUER %s AND SUBJECT %s", (pending_dict['username'], tls_options[1], tls_options[2], tls_options[3]))
220+
207221
conn.commit()
208222
logger.info("setSecret: Successfully set password for %s in MariaDB DB for secret arn %s." % (pending_dict['username'], arn))
209223
finally:
@@ -286,8 +300,9 @@ def finish_secret(service_client, arn, token):
286300
def get_connection(secret_dict):
287301
"""Gets a connection to MariaDB DB from a secret dictionary
288302
289-
This helper function tries to connect to the database grabbing connection info
290-
from the secret dictionary. If successful, it returns the connection, else None
303+
This helper function uses connectivity information from the secret dictionary to initiate
304+
connection attempt(s) to the database. Will attempt a fallback, non-SSL connection when
305+
initial connection fails using SSL and fall_back is True.
291306
292307
Args:
293308
secret_dict (dict): The Secret Dictionary
@@ -299,12 +314,88 @@ def get_connection(secret_dict):
299314
KeyError: If the secret json does not contain the expected keys
300315
301316
"""
317+
# Parse and validate the secret JSON string
302318
port = int(secret_dict['port']) if 'port' in secret_dict else 3306
303319
dbname = secret_dict['dbname'] if 'dbname' in secret_dict else None
304320

321+
# Get SSL connectivity configuration
322+
use_ssl, fall_back = get_ssl_config(secret_dict)
323+
324+
# if an 'ssl' key is not found or does not contain a valid value, attempt an SSL connection and fall back to non-SSL on failure
325+
conn = connect_and_authenticate(secret_dict, port, dbname, use_ssl)
326+
if conn or not fall_back:
327+
return conn
328+
else:
329+
return connect_and_authenticate(secret_dict, port, dbname, False)
330+
331+
332+
def get_ssl_config(secret_dict):
333+
"""Gets the desired SSL and fall back behavior using a secret dictionary
334+
335+
This helper function uses the existance and value the 'ssl' key in a secret dictionary
336+
to determine desired SSL connectivity configuration. Its behavior is as follows:
337+
- 'ssl' key DNE or invalid type/value: return True, True
338+
- 'ssl' key is bool: return secret_dict['ssl'], False
339+
- 'ssl' key equals "true" ignoring case: return True, False
340+
- 'ssl' key equals "false" ignoring case: return False, False
341+
342+
Args:
343+
secret_dict (dict): The Secret Dictionary
344+
345+
Returns:
346+
Tuple(use_ssl, fall_back): SSL configuration
347+
- use_ssl (bool): Flag indicating if an SSL connection should be attempted
348+
- fall_back (bool): Flag indicating if non-SSL connection should be attempted if SSL connection fails
349+
350+
"""
351+
# Default to True for SSL and fall_back mode if 'ssl' key DNE
352+
if 'ssl' not in secret_dict:
353+
return True, True
354+
355+
# Handle type bool
356+
if isinstance(secret_dict['ssl'], bool):
357+
return secret_dict['ssl'], False
358+
359+
# Handle type string
360+
if isinstance(secret_dict['ssl'], str):
361+
ssl = secret_dict['ssl'].lower()
362+
if ssl == "true":
363+
return True, False
364+
elif ssl == "false":
365+
return False, False
366+
else:
367+
# Invalid string value, default to True for both SSL and fall_back mode
368+
return True, True
369+
370+
# Invalid type, default to True for both SSL and fall_back mode
371+
return True, True
372+
373+
374+
def connect_and_authenticate(secret_dict, port, dbname, use_ssl):
375+
"""Attempt to connect and authenticate to a MariaDB instance
376+
377+
This helper function tries to connect to the database using connectivity info passed in.
378+
If successful, it returns the connection, else None
379+
380+
Args:
381+
- secret_dict (dict): The Secret Dictionary
382+
- port (int): The databse port to connect to
383+
- dbname (str): Name of the database
384+
- use_ssl (bool): Flag indicating whether connection should use SSL/TLS
385+
386+
Returns:
387+
Connection: The pymongo.database.Database object if successful. None otherwise
388+
389+
Raises:
390+
KeyError: If the secret json does not contain the expected keys
391+
392+
"""
393+
ssl = {'ca': '/etc/pki/tls/cert.pem'} if use_ssl else None
394+
305395
# Try to obtain a connection to the db
306396
try:
307-
conn = pymysql.connect(secret_dict['host'], user=secret_dict['username'], passwd=secret_dict['password'], port=port, db=dbname, connect_timeout=5)
397+
# Checks hostname and verifies server certificate implictly when 'ca' key is in 'ssl' dictionary
398+
conn = pymysql.connect(secret_dict['host'], user=secret_dict['username'], passwd=secret_dict['password'], port=port, db=dbname, connect_timeout=5, ssl=ssl)
308399
return conn
309400
except pymysql.OperationalError:
310401
return None
@@ -378,6 +469,7 @@ def get_alt_username(current_username):
378469
raise ValueError("Unable to clone user, username length with _clone appended would exceed 80 characters")
379470
return new_username
380471

472+
381473
def is_rds_replica_database(replica_dict, master_dict):
382474
"""Validates that the database of a secret is a replica of the database of the master secret
383475

0 commit comments

Comments
 (0)