Skip to content

Commit 6963d9e

Browse files
committed
Add SSL rotation support for PostgreSQL
1 parent 9ea89c9 commit 6963d9e

File tree

2 files changed

+190
-24
lines changed

2 files changed

+190
-24
lines changed

SecretsManagerRDSPostgreSQLRotationMultiUser/lambda_function.py

Lines changed: 91 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def lambda_handler(event, context):
1818
This handler uses the master-user rotation scheme to rotate an RDS PostgreSQL user credential. During the first rotation, this
1919
scheme logs into the database as the master user, creates a new user (appending _clone to the username), and grants the
2020
new user all of the permissions from the user being rotated. Once the secret is in this state, every subsequent rotation
21-
simply creates a new secret with the AWSPREVIOUS user credentials, adds any missing permissions that are in the current
22-
secret, changes that user's password, and then marks the latest secret as AWSCURRENT.
21+
simply creates a new secret with the AWSPREVIOUS user credentials, changes that user's password, and then marks the
22+
latest secret as AWSCURRENT.
2323
2424
The Secret SecretString is expected to be a JSON string with the following format:
2525
{
@@ -165,8 +165,10 @@ def set_secret(service_client, arn, token):
165165

166166
# Make sure the user from current and pending match
167167
if get_alt_username(current_dict['username']) != pending_dict['username']:
168-
logger.error("setSecret: Attempting to modify user %s other than current user clone %s" % (pending_dict['username'], get_alt_username(current_dict['username'])))
169-
raise ValueError("Attempting to modify user %s other than current user clone %s" % (pending_dict['username'], get_alt_username(current_dict['username'])))
168+
logger.error(
169+
"setSecret: Attempting to modify user %s other than current user or clone %s" % (pending_dict['username'], current_dict['username']))
170+
raise ValueError(
171+
"Attempting to modify user %s other than current user or clone %s" % (pending_dict['username'], current_dict['username']))
170172

171173
# Make sure the host from current and pending match
172174
if current_dict['host'] != pending_dict['host']:
@@ -216,7 +218,7 @@ def set_secret(service_client, arn, token):
216218
cur.execute(alter_role + " WITH PASSWORD %s", (pending_dict['password'],))
217219

218220
conn.commit()
219-
logger.info("setSecret: Successfully created user %s in PostgreSQL DB for secret arn %s." % (pending_dict['username'], arn))
221+
logger.info("setSecret: Successfully set password for %s in PostgreSQL DB for secret arn %s." % (pending_dict['username'], arn))
220222
finally:
221223
conn.close()
222224

@@ -297,8 +299,9 @@ def finish_secret(service_client, arn, token):
297299
def get_connection(secret_dict):
298300
"""Gets a connection to PostgreSQL DB from a secret dictionary
299301
300-
This helper function tries to connect to the database grabbing connection info
301-
from the secret dictionary. If successful, it returns the connection, else None
302+
This helper function uses connectivity information from the secret dictionary to initiate
303+
connection attempt(s) to the database. Will attempt a fallback, non-SSL connection when
304+
initial connection fails using SSL and fall_back is True.
302305
303306
Args:
304307
secret_dict (dict): The Secret Dictionary
@@ -314,9 +317,87 @@ def get_connection(secret_dict):
314317
port = int(secret_dict['port']) if 'port' in secret_dict else 5432
315318
dbname = secret_dict['dbname'] if 'dbname' in secret_dict else "postgres"
316319

320+
# Get SSL connectivity configuration
321+
use_ssl, fall_back = get_ssl_config(secret_dict)
322+
323+
# if an 'ssl' key is not found or does not contain a valid value, attempt an SSL connection and fall back to non-SSL on failure
324+
conn = connect_and_authenticate(secret_dict, port, dbname, use_ssl)
325+
if conn or not fall_back:
326+
return conn
327+
else:
328+
return connect_and_authenticate(secret_dict, port, dbname, False)
329+
330+
331+
def get_ssl_config(secret_dict):
332+
"""Gets the desired SSL and fall back behavior using a secret dictionary
333+
334+
This helper function uses the existance and value the 'ssl' key in a secret dictionary
335+
to determine desired SSL connectivity configuration. Its behavior is as follows:
336+
- 'ssl' key DNE or invalid type/value: return True, True
337+
- 'ssl' key is bool: return secret_dict['ssl'], False
338+
- 'ssl' key equals "true" ignoring case: return True, False
339+
- 'ssl' key equals "false" ignoring case: return False, False
340+
341+
Args:
342+
secret_dict (dict): The Secret Dictionary
343+
344+
Returns:
345+
Tuple(use_ssl, fall_back): SSL configuration
346+
- use_ssl (bool): Flag indicating if an SSL connection should be attempted
347+
- fall_back (bool): Flag indicating if non-SSL connection should be attempted if SSL connection fails
348+
349+
"""
350+
# Default to True for SSL and fall_back mode if 'ssl' key DNE
351+
if 'ssl' not in secret_dict:
352+
return True, True
353+
354+
# Handle type bool
355+
if isinstance(secret_dict['ssl'], bool):
356+
return secret_dict['ssl'], False
357+
358+
# Handle type string
359+
if isinstance(secret_dict['ssl'], str):
360+
ssl = secret_dict['ssl'].lower()
361+
if ssl == "true":
362+
return True, False
363+
elif ssl == "false":
364+
return False, False
365+
else:
366+
# Invalid string value, default to True for both SSL and fall_back mode
367+
return True, True
368+
369+
# Invalid type, default to True for both SSL and fall_back mode
370+
return True, True
371+
372+
373+
def connect_and_authenticate(secret_dict, port, dbname, use_ssl):
374+
"""Attempt to connect and authenticate to a PostgreSQL instance
375+
376+
This helper function tries to connect to the database using connectivity info passed in.
377+
If successful, it returns the connection, else None
378+
379+
Args:
380+
- secret_dict (dict): The Secret Dictionary
381+
- port (int): The databse port to connect to
382+
- dbname (str): Name of the database
383+
- use_ssl (bool): Flag indicating whether connection should use SSL/TLS
384+
385+
Returns:
386+
Connection: The pymongo.database.Database object if successful. None otherwise
387+
388+
Raises:
389+
KeyError: If the secret json does not contain the expected keys
390+
391+
"""
317392
# Try to obtain a connection to the db
318393
try:
319-
conn = pgdb.connect(host=secret_dict['host'], user=secret_dict['username'], password=secret_dict['password'], database=dbname, port=port, connect_timeout=5)
394+
if use_ssl:
395+
# Setting sslmode='verify-full' will verify the server's certificate and check the server's host name
396+
conn = pgdb.connect(host=secret_dict['host'], user=secret_dict['username'], password=secret_dict['password'], database=dbname, port=port,
397+
connect_timeout=5, sslrootcert='/etc/pki/tls/cert.pem', sslmode='verify-full')
398+
else:
399+
conn = pgdb.connect(host=secret_dict['host'], user=secret_dict['username'], password=secret_dict['password'], database=dbname, port=port,
400+
connect_timeout=5, sslmode='disable')
320401
return conn
321402
except pg.InternalError:
322403
return None
@@ -392,6 +473,7 @@ def get_alt_username(current_username):
392473
raise ValueError("Unable to clone user, username length with _clone appended would exceed 63 characters")
393474
return new_username
394475

476+
395477
def is_rds_replica_database(replica_dict, master_dict):
396478
"""Validates that the database of a secret is a replica of the database of the master secret
397479
@@ -429,4 +511,4 @@ def is_rds_replica_database(replica_dict, master_dict):
429511

430512
# DB Instance identifiers are unique - can only be one result
431513
current_instance = instances[0]
432-
return master_instance_id == current_instance.get('ReadReplicaSourceDBInstanceIdentifier')
514+
return master_instance_id == current_instance.get('ReadReplicaSourceDBInstanceIdentifier')

SecretsManagerRDSPostgreSQLRotationSingleUser/lambda_function.py

Lines changed: 99 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -172,18 +172,23 @@ def set_secret(service_client, arn, token):
172172

173173
# Now try the current password
174174
conn = get_connection(current_dict)
175-
if not conn:
176-
if previous_dict:
177-
# If both current and pending do not work, try previous
178-
conn = get_connection(previous_dict)
179-
180-
# Make sure the user/host from previous and pending match
181-
if previous_dict['username'] != pending_dict['username']:
182-
logger.error("setSecret: Attempting to modify user %s other than previous valid user %s" % (pending_dict['username'], previous_dict['username']))
183-
raise ValueError("Attempting to modify user %s other than previous valid user %s" % (pending_dict['username'], previous_dict['username']))
184-
if previous_dict['host'] != pending_dict['host']:
185-
logger.error("setSecret: Attempting to modify user for host %s other than previous valid host %s" % (pending_dict['host'], previous_dict['host']))
186-
raise ValueError("Attempting to modify user for host %s other than current previous valid %s" % (pending_dict['host'], previous_dict['host']))
175+
176+
# If both current and pending do not work, try previous
177+
if not conn and previous_dict:
178+
# Update previous_dict to leverage current SSL settings
179+
previous_dict.pop('ssl', None)
180+
if 'ssl' in current_dict:
181+
previous_dict['ssl'] = current_dict['ssl']
182+
183+
conn = get_connection(previous_dict)
184+
185+
# Make sure the user/host from previous and pending match
186+
if previous_dict['username'] != pending_dict['username']:
187+
logger.error("setSecret: Attempting to modify user %s other than previous valid user %s" % (pending_dict['username'], previous_dict['username']))
188+
raise ValueError("Attempting to modify user %s other than previous valid user %s" % (pending_dict['username'], previous_dict['username']))
189+
if previous_dict['host'] != pending_dict['host']:
190+
logger.error("setSecret: Attempting to modify user for host %s other than previous valid host %s" % (pending_dict['host'], previous_dict['host']))
191+
raise ValueError("Attempting to modify user for host %s other than current previous valid %s" % (pending_dict['host'], previous_dict['host']))
187192

188193
# If we still don't have a connection, raise a ValueError
189194
if not conn:
@@ -278,8 +283,9 @@ def finish_secret(service_client, arn, token):
278283
def get_connection(secret_dict):
279284
"""Gets a connection to PostgreSQL DB from a secret dictionary
280285
281-
This helper function tries to connect to the database grabbing connection info
282-
from the secret dictionary. If successful, it returns the connection, else None
286+
This helper function uses connectivity information from the secret dictionary to initiate
287+
connection attempt(s) to the database. Will attempt a fallback, non-SSL connection when
288+
initial connection fails using SSL and fall_back is True.
283289
284290
Args:
285291
secret_dict (dict): The Secret Dictionary
@@ -295,9 +301,87 @@ def get_connection(secret_dict):
295301
port = int(secret_dict['port']) if 'port' in secret_dict else 5432
296302
dbname = secret_dict['dbname'] if 'dbname' in secret_dict else "postgres"
297303

304+
# Get SSL connectivity configuration
305+
use_ssl, fall_back = get_ssl_config(secret_dict)
306+
307+
# if an 'ssl' key is not found or does not contain a valid value, attempt an SSL connection and fall back to non-SSL on failure
308+
conn = connect_and_authenticate(secret_dict, port, dbname, use_ssl)
309+
if conn or not fall_back:
310+
return conn
311+
else:
312+
return connect_and_authenticate(secret_dict, port, dbname, False)
313+
314+
315+
def get_ssl_config(secret_dict):
316+
"""Gets the desired SSL and fall back behavior using a secret dictionary
317+
318+
This helper function uses the existance and value the 'ssl' key in a secret dictionary
319+
to determine desired SSL connectivity configuration. Its behavior is as follows:
320+
- 'ssl' key DNE or invalid type/value: return True, True
321+
- 'ssl' key is bool: return secret_dict['ssl'], False
322+
- 'ssl' key equals "true" ignoring case: return True, False
323+
- 'ssl' key equals "false" ignoring case: return False, False
324+
325+
Args:
326+
secret_dict (dict): The Secret Dictionary
327+
328+
Returns:
329+
Tuple(use_ssl, fall_back): SSL configuration
330+
- use_ssl (bool): Flag indicating if an SSL connection should be attempted
331+
- fall_back (bool): Flag indicating if non-SSL connection should be attempted if SSL connection fails
332+
333+
"""
334+
# Default to True for SSL and fall_back mode if 'ssl' key DNE
335+
if 'ssl' not in secret_dict:
336+
return True, True
337+
338+
# Handle type bool
339+
if isinstance(secret_dict['ssl'], bool):
340+
return secret_dict['ssl'], False
341+
342+
# Handle type string
343+
if isinstance(secret_dict['ssl'], str):
344+
ssl = secret_dict['ssl'].lower()
345+
if ssl == "true":
346+
return True, False
347+
elif ssl == "false":
348+
return False, False
349+
else:
350+
# Invalid string value, default to True for both SSL and fall_back mode
351+
return True, True
352+
353+
# Invalid type, default to True for both SSL and fall_back mode
354+
return True, True
355+
356+
357+
def connect_and_authenticate(secret_dict, port, dbname, use_ssl):
358+
"""Attempt to connect and authenticate to a PostgreSQL instance
359+
360+
This helper function tries to connect to the database using connectivity info passed in.
361+
If successful, it returns the connection, else None
362+
363+
Args:
364+
- secret_dict (dict): The Secret Dictionary
365+
- port (int): The databse port to connect to
366+
- dbname (str): Name of the database
367+
- use_ssl (bool): Flag indicating whether connection should use SSL/TLS
368+
369+
Returns:
370+
Connection: The pymongo.database.Database object if successful. None otherwise
371+
372+
Raises:
373+
KeyError: If the secret json does not contain the expected keys
374+
375+
"""
298376
# Try to obtain a connection to the db
299377
try:
300-
conn = pgdb.connect(host=secret_dict['host'], user=secret_dict['username'], password=secret_dict['password'], database=dbname, port=port, connect_timeout=5)
378+
if use_ssl:
379+
# Setting sslmode='verify-full' will verify the server's certificate and check the server's host name
380+
conn = pgdb.connect(host=secret_dict['host'], user=secret_dict['username'], password=secret_dict['password'], database=dbname, port=port,
381+
connect_timeout=5, sslrootcert='/etc/pki/tls/cert.pem', sslmode='verify-full')
382+
else:
383+
conn = pgdb.connect(host=secret_dict['host'], user=secret_dict['username'], password=secret_dict['password'], database=dbname, port=port,
384+
connect_timeout=5, sslmode='disable')
301385
return conn
302386
except pg.InternalError:
303387
return None

0 commit comments

Comments
 (0)