@@ -17,8 +17,8 @@ def lambda_handler(event, context):
1717 This handler uses the master-user rotation scheme to rotate an RDS SQL Server user credential. During the first rotation, this
1818 scheme logs into the database as the master user, creates a new user (appending _clone to the username), and grants the
1919 new user all of the permissions from the user being rotated. Once the secret is in this state, every subsequent rotation
20- simply creates a new secret with the AWSPREVIOUS user credentials, adds any missing permissions that are in the current
21- secret, changes that user's password, and then marks the latest secret as AWSCURRENT.
20+ simply creates a new secret with the AWSPREVIOUS user credentials, changes that user's password, and then marks the
21+ latest secret as AWSCURRENT.
2222
2323 The Secret SecretString is expected to be a JSON string with the following format:
2424 {
@@ -166,7 +166,7 @@ def set_secret(service_client, arn, token):
166166 if get_alt_username (current_dict ['username' ]) != pending_dict ['username' ]:
167167 logger .error ("setSecret: Attempting to modify user %s other than current user or clone %s" % (pending_dict ['username' ], current_dict ['username' ]))
168168 raise ValueError ("Attempting to modify user %s other than current user or clone %s" % (pending_dict ['username' ], current_dict ['username' ]))
169-
169+
170170 # Make sure the host from current and pending match
171171 if current_dict ['host' ] != pending_dict ['host' ]:
172172 logger .error ("setSecret: Attempting to modify user for host %s other than current host %s" % (pending_dict ['host' ], current_dict ['host' ]))
@@ -205,7 +205,7 @@ def set_secret(service_client, arn, token):
205205
206206 # Determine if we are in a contained DB
207207 containment = 0
208- if not version .startswith ("Microsoft SQL Server 2008" ): # SQL Server 2008 does not support contained databases
208+ if not version .startswith ("Microsoft SQL Server 2008" ): # SQL Server 2008 does not support contained databases
209209 cursor .execute ("SELECT containment FROM sys.databases WHERE name = %s" , current_db )
210210 containment = cursor .fetchall ()[0 ]['containment' ]
211211
@@ -216,7 +216,7 @@ def set_secret(service_client, arn, token):
216216 set_password_for_user (cursor , current_dict ['username' ], pending_dict )
217217
218218 conn .commit ()
219- logger .info ("setSecret: Successfully created user %s in SQL Server DB for secret arn %s." % (pending_dict ['username' ], arn ))
219+ logger .info ("setSecret: Successfully set password for %s in SQL Server DB for secret arn %s." % (pending_dict ['username' ], arn ))
220220 finally :
221221 conn .close ()
222222
@@ -294,10 +294,11 @@ def finish_secret(service_client, arn, token):
294294
295295
296296def get_connection (secret_dict ):
297- """Gets a connection to SQL Server DB from a secret dictionary
297+ """Gets a connection to a SQL Server DB from a secret dictionary
298298
299- This helper function tries to connect to the database grabbing connection info
300- from the secret dictionary. If successful, it returns the connection, else None
299+ This helper function uses connectivity information from the secret dictionary to initiate
300+ connection attempt(s) to the database. Will attempt a fallback, non-SSL connection when
301+ initial connection fails using SSL and fall_back is True.
301302
302303 Args:
303304 secret_dict (dict): The Secret Dictionary
@@ -313,6 +314,81 @@ def get_connection(secret_dict):
313314 port = str (secret_dict ['port' ]) if 'port' in secret_dict else '1433'
314315 dbname = secret_dict ['dbname' ] if 'dbname' in secret_dict else 'master'
315316
317+ # Get SSL connectivity configuration
318+ use_ssl , fall_back = get_ssl_config (secret_dict )
319+
320+ # if an 'ssl' key is not found or does not contain a valid value, attempt an SSL connection and fall back to non-SSL on failure
321+ conn = connect_and_authenticate (secret_dict , port , dbname , use_ssl )
322+ if conn or not fall_back :
323+ return conn
324+ else :
325+ return connect_and_authenticate (secret_dict , port , dbname , False )
326+
327+
328+ def get_ssl_config (secret_dict ):
329+ """Gets the desired SSL and fall back behavior using a secret dictionary
330+
331+ This helper function uses the existance and value the 'ssl' key in a secret dictionary
332+ to determine desired SSL connectivity configuration. Its behavior is as follows:
333+ - 'ssl' key DNE or invalid type/value: return True, True
334+ - 'ssl' key is bool: return secret_dict['ssl'], False
335+ - 'ssl' key equals "true" ignoring case: return True, False
336+ - 'ssl' key equals "false" ignoring case: return False, False
337+
338+ Args:
339+ secret_dict (dict): The Secret Dictionary
340+
341+ Returns:
342+ Tuple(use_ssl, fall_back): SSL configuration
343+ - use_ssl (bool): Flag indicating if an SSL connection should be attempted
344+ - fall_back (bool): Flag indicating if non-SSL connection should be attempted if SSL connection fails
345+
346+ """
347+ # Default to True for SSL and fall_back mode if 'ssl' key DNE
348+ if 'ssl' not in secret_dict :
349+ return True , True
350+
351+ # Handle type bool
352+ if isinstance (secret_dict ['ssl' ], bool ):
353+ return secret_dict ['ssl' ], False
354+
355+ # Handle type string
356+ if isinstance (secret_dict ['ssl' ], str ):
357+ ssl = secret_dict ['ssl' ].lower ()
358+ if ssl == "true" :
359+ return True , False
360+ elif ssl == "false" :
361+ return False , False
362+ else :
363+ # Invalid string value, default to True for both SSL and fall_back mode
364+ return True , True
365+
366+ # Invalid type, default to True for both SSL and fall_back mode
367+ return True , True
368+
369+
370+ def connect_and_authenticate (secret_dict , port , dbname , use_ssl ):
371+ """Attempt to connect and authenticate to a SQL Server DB
372+
373+ This helper function tries to connect to the database using connectivity info passed in.
374+ If successful, it returns the connection, else None
375+
376+ Args:
377+ - secret_dict (dict): The Secret Dictionary
378+ - port (int): The databse port to connect to
379+ - dbname (str): Name of the database
380+ - use_ssl (bool): Flag indicating whether connection should use SSL/TLS
381+
382+ Returns:
383+ Connection: The pymssql.Connection object if successful. None otherwise
384+
385+ Raises:
386+ KeyError: If the secret json does not contain the expected keys
387+
388+ """
389+ # Dynamically set tds configuration based on ssl flag
390+ os .environ ['FREETDSCONF' ] = '/var/task/%s' % 'freetds_ssl.conf' if use_ssl else 'freetds.conf'
391+
316392 # Try to obtain a connection to the db
317393 try :
318394 conn = pymssql .connect (server = secret_dict ['host' ],
@@ -431,8 +507,8 @@ def set_password_for_login(cursor, current_db, current_login, pending_dict):
431507 # Only handle server level permissions if we are connected the the master DB
432508 if current_db == 'master' :
433509 # Loop through the types of server permissions and grant them to the new login
434- query = "SELECT state_desc, permission_name FROM sys.server_permissions perm " \
435- "JOIN sys.server_principals prin ON perm.grantee_principal_id = prin.principal_id " \
510+ query = "SELECT state_desc, permission_name FROM sys.server_permissions perm " \
511+ "JOIN sys.server_principals prin ON perm.grantee_principal_id = prin.principal_id " \
436512 "WHERE prin.name = %s"
437513 cursor .execute (query , current_login )
438514 for row in cursor .fetchall ():
@@ -444,7 +520,9 @@ def set_password_for_login(cursor, current_db, current_login, pending_dict):
444520 # We do not create user objects in the master database
445521 else :
446522 # Get the user for the current login and generate the alt user
447- cursor .execute ("SELECT dbprin.name FROM sys.database_principals dbprin JOIN sys.server_principals sprin ON dbprin.sid = sprin.sid WHERE sprin.name = %s" , current_login )
523+ cursor .execute (
524+ "SELECT dbprin.name FROM sys.database_principals dbprin JOIN sys.server_principals sprin ON dbprin.sid = sprin.sid WHERE sprin.name = %s" ,
525+ current_login )
448526 cur_user = cursor .fetchall ()[0 ]['name' ]
449527 alt_user = get_alt_username (cur_user )
450528
@@ -516,9 +594,9 @@ def apply_database_permissions(cursor, current_user, pending_user):
516594
517595 """
518596 # Get the roles assigned to the current user and assign it to the pending user
519- query = "SELECT roleprin.name FROM sys.database_role_members rolemems " \
520- "JOIN sys.database_principals roleprin ON roleprin.principal_id = rolemems.role_principal_id " \
521- "JOIN sys.database_principals userprin ON userprin.principal_id = rolemems.member_principal_id " \
597+ query = "SELECT roleprin.name FROM sys.database_role_members rolemems " \
598+ "JOIN sys.database_principals roleprin ON roleprin.principal_id = rolemems.role_principal_id " \
599+ "JOIN sys.database_principals userprin ON userprin.principal_id = rolemems.member_principal_id " \
522600 "WHERE userprin.name = %s"
523601 cursor .execute (query , current_user )
524602 for row in cursor .fetchall ():
@@ -528,69 +606,69 @@ def apply_database_permissions(cursor, current_user, pending_user):
528606 cursor .execute (sql_stmt )
529607
530608 # Loop through the database permissions and grant them to the user
531- query = "SELECT " \
532- "class = perm.class, " \
533- "state_desc = perm.state_desc, " \
534- "perm_name = perm.permission_name, " \
535- "schema_name = permschem.name, " \
536- "obj_name = obj.name, " \
537- "obj_schema_name = objschem.name, " \
538- "col_name = col.name, " \
539- "imp_name = imp.name, " \
540- "imp_type = imp.type, " \
541- "assembly_name = assembly.name, " \
542- "type_name = types.name, " \
543- "type_schema = typeschem.name, " \
544- "schema_coll_name = schema_coll.name, " \
545- "xml_schema = xmlschem.name, " \
546- "msg_type_name = msg_type.name, " \
547- "contract_name = contract.name, " \
548- "svc_name = svc.name, " \
549- "binding_name = binding.name, " \
550- "route_name = route.name, " \
551- "catalog_name = catalog.name, " \
552- "symkey_name = symkey.name, " \
553- "cert_name = cert.name, " \
554- "asymkey_name = asymkey.name " \
555- "FROM sys.database_permissions perm " \
556- "JOIN sys.database_principals prin ON perm.grantee_principal_id = prin.principal_id " \
557- "LEFT JOIN sys.schemas permschem ON permschem.schema_id = perm.major_id " \
558- "LEFT JOIN sys.objects obj ON obj.object_id = perm.major_id " \
559- "LEFT JOIN sys.schemas objschem ON objschem.schema_id = obj.schema_id " \
560- "LEFT JOIN sys.columns col ON col.object_id = perm.major_id AND col.column_id = perm.minor_id " \
561- "LEFT JOIN sys.database_principals imp ON imp.principal_id = perm.major_id " \
562- "LEFT JOIN sys.assemblies assembly ON assembly.assembly_id = perm.major_id " \
563- "LEFT JOIN sys.types types ON types.user_type_id = perm.major_id " \
564- "LEFT JOIN sys.schemas typeschem ON typeschem.schema_id = types.schema_id " \
565- "LEFT JOIN sys.xml_schema_collections schema_coll ON schema_coll.xml_collection_id = perm.major_id " \
566- "LEFT JOIN sys.schemas xmlschem ON xmlschem.schema_id = schema_coll.schema_id " \
567- "LEFT JOIN sys.service_message_types msg_type ON msg_type.message_type_id = perm.major_id " \
568- "LEFT JOIN sys.service_contracts contract ON contract.service_contract_id = perm.major_id " \
569- "LEFT JOIN sys.services svc ON svc.service_id = perm.major_id " \
570- "LEFT JOIN sys.remote_service_bindings binding ON binding.remote_service_binding_id = perm.major_id " \
571- "LEFT JOIN sys.routes route ON route.route_id = perm.major_id " \
572- "LEFT JOIN sys.fulltext_catalogs catalog ON catalog.fulltext_catalog_id = perm.major_id " \
573- "LEFT JOIN sys.symmetric_keys symkey ON symkey.symmetric_key_id = perm.major_id " \
574- "LEFT JOIN sys.certificates cert ON cert.certificate_id = perm.major_id " \
575- "LEFT JOIN sys.asymmetric_keys asymkey ON asymkey.asymmetric_key_id = perm.major_id " \
609+ query = "SELECT " \
610+ "class = perm.class, " \
611+ "state_desc = perm.state_desc, " \
612+ "perm_name = perm.permission_name, " \
613+ "schema_name = permschem.name, " \
614+ "obj_name = obj.name, " \
615+ "obj_schema_name = objschem.name, " \
616+ "col_name = col.name, " \
617+ "imp_name = imp.name, " \
618+ "imp_type = imp.type, " \
619+ "assembly_name = assembly.name, " \
620+ "type_name = types.name, " \
621+ "type_schema = typeschem.name, " \
622+ "schema_coll_name = schema_coll.name, " \
623+ "xml_schema = xmlschem.name, " \
624+ "msg_type_name = msg_type.name, " \
625+ "contract_name = contract.name, " \
626+ "svc_name = svc.name, " \
627+ "binding_name = binding.name, " \
628+ "route_name = route.name, " \
629+ "catalog_name = catalog.name, " \
630+ "symkey_name = symkey.name, " \
631+ "cert_name = cert.name, " \
632+ "asymkey_name = asymkey.name " \
633+ "FROM sys.database_permissions perm " \
634+ "JOIN sys.database_principals prin ON perm.grantee_principal_id = prin.principal_id " \
635+ "LEFT JOIN sys.schemas permschem ON permschem.schema_id = perm.major_id " \
636+ "LEFT JOIN sys.objects obj ON obj.object_id = perm.major_id " \
637+ "LEFT JOIN sys.schemas objschem ON objschem.schema_id = obj.schema_id " \
638+ "LEFT JOIN sys.columns col ON col.object_id = perm.major_id AND col.column_id = perm.minor_id " \
639+ "LEFT JOIN sys.database_principals imp ON imp.principal_id = perm.major_id " \
640+ "LEFT JOIN sys.assemblies assembly ON assembly.assembly_id = perm.major_id " \
641+ "LEFT JOIN sys.types types ON types.user_type_id = perm.major_id " \
642+ "LEFT JOIN sys.schemas typeschem ON typeschem.schema_id = types.schema_id " \
643+ "LEFT JOIN sys.xml_schema_collections schema_coll ON schema_coll.xml_collection_id = perm.major_id " \
644+ "LEFT JOIN sys.schemas xmlschem ON xmlschem.schema_id = schema_coll.schema_id " \
645+ "LEFT JOIN sys.service_message_types msg_type ON msg_type.message_type_id = perm.major_id " \
646+ "LEFT JOIN sys.service_contracts contract ON contract.service_contract_id = perm.major_id " \
647+ "LEFT JOIN sys.services svc ON svc.service_id = perm.major_id " \
648+ "LEFT JOIN sys.remote_service_bindings binding ON binding.remote_service_binding_id = perm.major_id " \
649+ "LEFT JOIN sys.routes route ON route.route_id = perm.major_id " \
650+ "LEFT JOIN sys.fulltext_catalogs catalog ON catalog.fulltext_catalog_id = perm.major_id " \
651+ "LEFT JOIN sys.symmetric_keys symkey ON symkey.symmetric_key_id = perm.major_id " \
652+ "LEFT JOIN sys.certificates cert ON cert.certificate_id = perm.major_id " \
653+ "LEFT JOIN sys.asymmetric_keys asymkey ON asymkey.asymmetric_key_id = perm.major_id " \
576654 "WHERE prin.name = %s"
577655 cursor .execute (query , current_user )
578656 for row in cursor .fetchall ():
579657 # Determine which type of permission this is and create the sql statement accordingly
580- if row ['class' ] == 0 : # Database permission
658+ if row ['class' ] == 0 : # Database permission
581659 permission = row ['perm_name' ]
582- elif row ['class' ] == 1 : # Object or Column
660+ elif row ['class' ] == 1 : # Object or Column
583661 permission = "%s ON OBJECT::%s.%s" % (row ['perm_name' ], row ['obj_schema_name' ], row ['obj_name' ])
584662 if row ['col_name' ]:
585663 permission = "%s (%s) " % (permission , row ['col_name' ])
586- elif row ['class' ] == 3 : # Schema
664+ elif row ['class' ] == 3 : # Schema
587665 permission = "%s ON SCHEMA::%s" % (row ['perm_name' ], row ['schema_name' ])
588- elif row ['class' ] == 4 : # Impersonation (Database Principal)
589- if row ['imp_type' ] == 'S' : # SQL User
666+ elif row ['class' ] == 4 : # Impersonation (Database Principal)
667+ if row ['imp_type' ] == 'S' : # SQL User
590668 permission = "%s ON USER::%s" % (row ['perm_name' ], row ['imp_name' ])
591- elif row ['imp_type' ] == 'R' : # Role
669+ elif row ['imp_type' ] == 'R' : # Role
592670 permission = "%s ON ROLE::%s" % (row ['perm_name' ], row ['imp_name' ])
593- elif row ['imp_type' ] == 'A' : # Application Role
671+ elif row ['imp_type' ] == 'A' : # Application Role
594672 permission = "%s ON APPLICATION ROLE::%s" % (row ['perm_name' ], row ['imp_name' ])
595673 else :
596674 raise ValueError ("Invalid database principal permission type %s" % row ['imp_type' ])
@@ -630,6 +708,7 @@ def apply_database_permissions(cursor, current_user, pending_user):
630708 # Execute the sql
631709 cursor .execute (sql_stmt )
632710
711+
633712def is_rds_replica_database (replica_dict , master_dict ):
634713 """Validates that the database of a secret is a replica of the database of the master secret
635714
0 commit comments