Skip to content

Commit 60fedc9

Browse files
committed
Add SSL rotation support for SQL Server
1 parent 6963d9e commit 60fedc9

File tree

2 files changed

+232
-71
lines changed

2 files changed

+232
-71
lines changed

SecretsManagerRDSSQLServerRotationMultiUser/lambda_function.py

Lines changed: 145 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def lambda_handler(event, context):
1717
This handler uses the master-user rotation scheme to rotate an RDS SQL Server user credential. During the first rotation, this
1818
scheme logs into the database as the master user, creates a new user (appending _clone to the username), and grants the
1919
new user all of the permissions from the user being rotated. Once the secret is in this state, every subsequent rotation
20-
simply creates a new secret with the AWSPREVIOUS user credentials, adds any missing permissions that are in the current
21-
secret, changes that user's password, and then marks the latest secret as AWSCURRENT.
20+
simply creates a new secret with the AWSPREVIOUS user credentials, changes that user's password, and then marks the
21+
latest secret as AWSCURRENT.
2222
2323
The Secret SecretString is expected to be a JSON string with the following format:
2424
{
@@ -166,7 +166,7 @@ def set_secret(service_client, arn, token):
166166
if get_alt_username(current_dict['username']) != pending_dict['username']:
167167
logger.error("setSecret: Attempting to modify user %s other than current user or clone %s" % (pending_dict['username'], current_dict['username']))
168168
raise ValueError("Attempting to modify user %s other than current user or clone %s" % (pending_dict['username'], current_dict['username']))
169-
169+
170170
# Make sure the host from current and pending match
171171
if current_dict['host'] != pending_dict['host']:
172172
logger.error("setSecret: Attempting to modify user for host %s other than current host %s" % (pending_dict['host'], current_dict['host']))
@@ -205,7 +205,7 @@ def set_secret(service_client, arn, token):
205205

206206
# Determine if we are in a contained DB
207207
containment = 0
208-
if not version.startswith("Microsoft SQL Server 2008"): # SQL Server 2008 does not support contained databases
208+
if not version.startswith("Microsoft SQL Server 2008"): # SQL Server 2008 does not support contained databases
209209
cursor.execute("SELECT containment FROM sys.databases WHERE name = %s", current_db)
210210
containment = cursor.fetchall()[0]['containment']
211211

@@ -216,7 +216,7 @@ def set_secret(service_client, arn, token):
216216
set_password_for_user(cursor, current_dict['username'], pending_dict)
217217

218218
conn.commit()
219-
logger.info("setSecret: Successfully created user %s in SQL Server DB for secret arn %s." % (pending_dict['username'], arn))
219+
logger.info("setSecret: Successfully set password for %s in SQL Server DB for secret arn %s." % (pending_dict['username'], arn))
220220
finally:
221221
conn.close()
222222

@@ -294,10 +294,11 @@ def finish_secret(service_client, arn, token):
294294

295295

296296
def get_connection(secret_dict):
297-
"""Gets a connection to SQL Server DB from a secret dictionary
297+
"""Gets a connection to a SQL Server DB from a secret dictionary
298298
299-
This helper function tries to connect to the database grabbing connection info
300-
from the secret dictionary. If successful, it returns the connection, else None
299+
This helper function uses connectivity information from the secret dictionary to initiate
300+
connection attempt(s) to the database. Will attempt a fallback, non-SSL connection when
301+
initial connection fails using SSL and fall_back is True.
301302
302303
Args:
303304
secret_dict (dict): The Secret Dictionary
@@ -313,6 +314,81 @@ def get_connection(secret_dict):
313314
port = str(secret_dict['port']) if 'port' in secret_dict else '1433'
314315
dbname = secret_dict['dbname'] if 'dbname' in secret_dict else 'master'
315316

317+
# Get SSL connectivity configuration
318+
use_ssl, fall_back = get_ssl_config(secret_dict)
319+
320+
# if an 'ssl' key is not found or does not contain a valid value, attempt an SSL connection and fall back to non-SSL on failure
321+
conn = connect_and_authenticate(secret_dict, port, dbname, use_ssl)
322+
if conn or not fall_back:
323+
return conn
324+
else:
325+
return connect_and_authenticate(secret_dict, port, dbname, False)
326+
327+
328+
def get_ssl_config(secret_dict):
329+
"""Gets the desired SSL and fall back behavior using a secret dictionary
330+
331+
This helper function uses the existance and value the 'ssl' key in a secret dictionary
332+
to determine desired SSL connectivity configuration. Its behavior is as follows:
333+
- 'ssl' key DNE or invalid type/value: return True, True
334+
- 'ssl' key is bool: return secret_dict['ssl'], False
335+
- 'ssl' key equals "true" ignoring case: return True, False
336+
- 'ssl' key equals "false" ignoring case: return False, False
337+
338+
Args:
339+
secret_dict (dict): The Secret Dictionary
340+
341+
Returns:
342+
Tuple(use_ssl, fall_back): SSL configuration
343+
- use_ssl (bool): Flag indicating if an SSL connection should be attempted
344+
- fall_back (bool): Flag indicating if non-SSL connection should be attempted if SSL connection fails
345+
346+
"""
347+
# Default to True for SSL and fall_back mode if 'ssl' key DNE
348+
if 'ssl' not in secret_dict:
349+
return True, True
350+
351+
# Handle type bool
352+
if isinstance(secret_dict['ssl'], bool):
353+
return secret_dict['ssl'], False
354+
355+
# Handle type string
356+
if isinstance(secret_dict['ssl'], str):
357+
ssl = secret_dict['ssl'].lower()
358+
if ssl == "true":
359+
return True, False
360+
elif ssl == "false":
361+
return False, False
362+
else:
363+
# Invalid string value, default to True for both SSL and fall_back mode
364+
return True, True
365+
366+
# Invalid type, default to True for both SSL and fall_back mode
367+
return True, True
368+
369+
370+
def connect_and_authenticate(secret_dict, port, dbname, use_ssl):
371+
"""Attempt to connect and authenticate to a SQL Server DB
372+
373+
This helper function tries to connect to the database using connectivity info passed in.
374+
If successful, it returns the connection, else None
375+
376+
Args:
377+
- secret_dict (dict): The Secret Dictionary
378+
- port (int): The databse port to connect to
379+
- dbname (str): Name of the database
380+
- use_ssl (bool): Flag indicating whether connection should use SSL/TLS
381+
382+
Returns:
383+
Connection: The pymssql.Connection object if successful. None otherwise
384+
385+
Raises:
386+
KeyError: If the secret json does not contain the expected keys
387+
388+
"""
389+
# Dynamically set tds configuration based on ssl flag
390+
os.environ['FREETDSCONF'] = '/var/task/%s' % 'freetds_ssl.conf' if use_ssl else 'freetds.conf'
391+
316392
# Try to obtain a connection to the db
317393
try:
318394
conn = pymssql.connect(server=secret_dict['host'],
@@ -431,8 +507,8 @@ def set_password_for_login(cursor, current_db, current_login, pending_dict):
431507
# Only handle server level permissions if we are connected the the master DB
432508
if current_db == 'master':
433509
# Loop through the types of server permissions and grant them to the new login
434-
query = "SELECT state_desc, permission_name FROM sys.server_permissions perm "\
435-
"JOIN sys.server_principals prin ON perm.grantee_principal_id = prin.principal_id "\
510+
query = "SELECT state_desc, permission_name FROM sys.server_permissions perm " \
511+
"JOIN sys.server_principals prin ON perm.grantee_principal_id = prin.principal_id " \
436512
"WHERE prin.name = %s"
437513
cursor.execute(query, current_login)
438514
for row in cursor.fetchall():
@@ -444,7 +520,9 @@ def set_password_for_login(cursor, current_db, current_login, pending_dict):
444520
# We do not create user objects in the master database
445521
else:
446522
# Get the user for the current login and generate the alt user
447-
cursor.execute("SELECT dbprin.name FROM sys.database_principals dbprin JOIN sys.server_principals sprin ON dbprin.sid = sprin.sid WHERE sprin.name = %s", current_login)
523+
cursor.execute(
524+
"SELECT dbprin.name FROM sys.database_principals dbprin JOIN sys.server_principals sprin ON dbprin.sid = sprin.sid WHERE sprin.name = %s",
525+
current_login)
448526
cur_user = cursor.fetchall()[0]['name']
449527
alt_user = get_alt_username(cur_user)
450528

@@ -516,9 +594,9 @@ def apply_database_permissions(cursor, current_user, pending_user):
516594
517595
"""
518596
# Get the roles assigned to the current user and assign it to the pending user
519-
query = "SELECT roleprin.name FROM sys.database_role_members rolemems "\
520-
"JOIN sys.database_principals roleprin ON roleprin.principal_id = rolemems.role_principal_id "\
521-
"JOIN sys.database_principals userprin ON userprin.principal_id = rolemems.member_principal_id "\
597+
query = "SELECT roleprin.name FROM sys.database_role_members rolemems " \
598+
"JOIN sys.database_principals roleprin ON roleprin.principal_id = rolemems.role_principal_id " \
599+
"JOIN sys.database_principals userprin ON userprin.principal_id = rolemems.member_principal_id " \
522600
"WHERE userprin.name = %s"
523601
cursor.execute(query, current_user)
524602
for row in cursor.fetchall():
@@ -528,69 +606,69 @@ def apply_database_permissions(cursor, current_user, pending_user):
528606
cursor.execute(sql_stmt)
529607

530608
# Loop through the database permissions and grant them to the user
531-
query = "SELECT "\
532-
"class = perm.class, "\
533-
"state_desc = perm.state_desc, "\
534-
"perm_name = perm.permission_name, "\
535-
"schema_name = permschem.name, "\
536-
"obj_name = obj.name, "\
537-
"obj_schema_name = objschem.name, "\
538-
"col_name = col.name, "\
539-
"imp_name = imp.name, "\
540-
"imp_type = imp.type, "\
541-
"assembly_name = assembly.name, "\
542-
"type_name = types.name, "\
543-
"type_schema = typeschem.name, "\
544-
"schema_coll_name = schema_coll.name, "\
545-
"xml_schema = xmlschem.name, "\
546-
"msg_type_name = msg_type.name, "\
547-
"contract_name = contract.name, "\
548-
"svc_name = svc.name, "\
549-
"binding_name = binding.name, "\
550-
"route_name = route.name, "\
551-
"catalog_name = catalog.name, "\
552-
"symkey_name = symkey.name, "\
553-
"cert_name = cert.name, "\
554-
"asymkey_name = asymkey.name "\
555-
"FROM sys.database_permissions perm "\
556-
"JOIN sys.database_principals prin ON perm.grantee_principal_id = prin.principal_id "\
557-
"LEFT JOIN sys.schemas permschem ON permschem.schema_id = perm.major_id "\
558-
"LEFT JOIN sys.objects obj ON obj.object_id = perm.major_id "\
559-
"LEFT JOIN sys.schemas objschem ON objschem.schema_id = obj.schema_id "\
560-
"LEFT JOIN sys.columns col ON col.object_id = perm.major_id AND col.column_id = perm.minor_id "\
561-
"LEFT JOIN sys.database_principals imp ON imp.principal_id = perm.major_id "\
562-
"LEFT JOIN sys.assemblies assembly ON assembly.assembly_id = perm.major_id "\
563-
"LEFT JOIN sys.types types ON types.user_type_id = perm.major_id "\
564-
"LEFT JOIN sys.schemas typeschem ON typeschem.schema_id = types.schema_id "\
565-
"LEFT JOIN sys.xml_schema_collections schema_coll ON schema_coll.xml_collection_id = perm.major_id "\
566-
"LEFT JOIN sys.schemas xmlschem ON xmlschem.schema_id = schema_coll.schema_id "\
567-
"LEFT JOIN sys.service_message_types msg_type ON msg_type.message_type_id = perm.major_id "\
568-
"LEFT JOIN sys.service_contracts contract ON contract.service_contract_id = perm.major_id "\
569-
"LEFT JOIN sys.services svc ON svc.service_id = perm.major_id "\
570-
"LEFT JOIN sys.remote_service_bindings binding ON binding.remote_service_binding_id = perm.major_id "\
571-
"LEFT JOIN sys.routes route ON route.route_id = perm.major_id "\
572-
"LEFT JOIN sys.fulltext_catalogs catalog ON catalog.fulltext_catalog_id = perm.major_id "\
573-
"LEFT JOIN sys.symmetric_keys symkey ON symkey.symmetric_key_id = perm.major_id "\
574-
"LEFT JOIN sys.certificates cert ON cert.certificate_id = perm.major_id "\
575-
"LEFT JOIN sys.asymmetric_keys asymkey ON asymkey.asymmetric_key_id = perm.major_id "\
609+
query = "SELECT " \
610+
"class = perm.class, " \
611+
"state_desc = perm.state_desc, " \
612+
"perm_name = perm.permission_name, " \
613+
"schema_name = permschem.name, " \
614+
"obj_name = obj.name, " \
615+
"obj_schema_name = objschem.name, " \
616+
"col_name = col.name, " \
617+
"imp_name = imp.name, " \
618+
"imp_type = imp.type, " \
619+
"assembly_name = assembly.name, " \
620+
"type_name = types.name, " \
621+
"type_schema = typeschem.name, " \
622+
"schema_coll_name = schema_coll.name, " \
623+
"xml_schema = xmlschem.name, " \
624+
"msg_type_name = msg_type.name, " \
625+
"contract_name = contract.name, " \
626+
"svc_name = svc.name, " \
627+
"binding_name = binding.name, " \
628+
"route_name = route.name, " \
629+
"catalog_name = catalog.name, " \
630+
"symkey_name = symkey.name, " \
631+
"cert_name = cert.name, " \
632+
"asymkey_name = asymkey.name " \
633+
"FROM sys.database_permissions perm " \
634+
"JOIN sys.database_principals prin ON perm.grantee_principal_id = prin.principal_id " \
635+
"LEFT JOIN sys.schemas permschem ON permschem.schema_id = perm.major_id " \
636+
"LEFT JOIN sys.objects obj ON obj.object_id = perm.major_id " \
637+
"LEFT JOIN sys.schemas objschem ON objschem.schema_id = obj.schema_id " \
638+
"LEFT JOIN sys.columns col ON col.object_id = perm.major_id AND col.column_id = perm.minor_id " \
639+
"LEFT JOIN sys.database_principals imp ON imp.principal_id = perm.major_id " \
640+
"LEFT JOIN sys.assemblies assembly ON assembly.assembly_id = perm.major_id " \
641+
"LEFT JOIN sys.types types ON types.user_type_id = perm.major_id " \
642+
"LEFT JOIN sys.schemas typeschem ON typeschem.schema_id = types.schema_id " \
643+
"LEFT JOIN sys.xml_schema_collections schema_coll ON schema_coll.xml_collection_id = perm.major_id " \
644+
"LEFT JOIN sys.schemas xmlschem ON xmlschem.schema_id = schema_coll.schema_id " \
645+
"LEFT JOIN sys.service_message_types msg_type ON msg_type.message_type_id = perm.major_id " \
646+
"LEFT JOIN sys.service_contracts contract ON contract.service_contract_id = perm.major_id " \
647+
"LEFT JOIN sys.services svc ON svc.service_id = perm.major_id " \
648+
"LEFT JOIN sys.remote_service_bindings binding ON binding.remote_service_binding_id = perm.major_id " \
649+
"LEFT JOIN sys.routes route ON route.route_id = perm.major_id " \
650+
"LEFT JOIN sys.fulltext_catalogs catalog ON catalog.fulltext_catalog_id = perm.major_id " \
651+
"LEFT JOIN sys.symmetric_keys symkey ON symkey.symmetric_key_id = perm.major_id " \
652+
"LEFT JOIN sys.certificates cert ON cert.certificate_id = perm.major_id " \
653+
"LEFT JOIN sys.asymmetric_keys asymkey ON asymkey.asymmetric_key_id = perm.major_id " \
576654
"WHERE prin.name = %s"
577655
cursor.execute(query, current_user)
578656
for row in cursor.fetchall():
579657
# Determine which type of permission this is and create the sql statement accordingly
580-
if row['class'] == 0: # Database permission
658+
if row['class'] == 0: # Database permission
581659
permission = row['perm_name']
582-
elif row['class'] == 1: # Object or Column
660+
elif row['class'] == 1: # Object or Column
583661
permission = "%s ON OBJECT::%s.%s" % (row['perm_name'], row['obj_schema_name'], row['obj_name'])
584662
if row['col_name']:
585663
permission = "%s (%s) " % (permission, row['col_name'])
586-
elif row['class'] == 3: # Schema
664+
elif row['class'] == 3: # Schema
587665
permission = "%s ON SCHEMA::%s" % (row['perm_name'], row['schema_name'])
588-
elif row['class'] == 4: # Impersonation (Database Principal)
589-
if row['imp_type'] == 'S': # SQL User
666+
elif row['class'] == 4: # Impersonation (Database Principal)
667+
if row['imp_type'] == 'S': # SQL User
590668
permission = "%s ON USER::%s" % (row['perm_name'], row['imp_name'])
591-
elif row['imp_type'] == 'R': # Role
669+
elif row['imp_type'] == 'R': # Role
592670
permission = "%s ON ROLE::%s" % (row['perm_name'], row['imp_name'])
593-
elif row['imp_type'] == 'A': # Application Role
671+
elif row['imp_type'] == 'A': # Application Role
594672
permission = "%s ON APPLICATION ROLE::%s" % (row['perm_name'], row['imp_name'])
595673
else:
596674
raise ValueError("Invalid database principal permission type %s" % row['imp_type'])
@@ -630,6 +708,7 @@ def apply_database_permissions(cursor, current_user, pending_user):
630708
# Execute the sql
631709
cursor.execute(sql_stmt)
632710

711+
633712
def is_rds_replica_database(replica_dict, master_dict):
634713
"""Validates that the database of a secret is a replica of the database of the master secret
635714

0 commit comments

Comments
 (0)