@@ -17,8 +17,8 @@ def lambda_handler(event, context):
1717 This handler uses the master-user rotation scheme to rotate an RDS MariaDB user credential. During the first rotation, this
1818 scheme logs into the database as the master user, creates a new user (appending _clone to the username), and grants the
1919 new user all of the permissions from the user being rotated. Once the secret is in this state, every subsequent rotation
20- simply creates a new secret with the AWSPREVIOUS user credentials, adds any missing permissions that are in the current
21- secret, changes that user's password, and then marks the latest secret as AWSCURRENT.
20+ simply creates a new secret with the AWSPREVIOUS user credentials, changes that user's password, and then marks the
21+ latest secret as AWSCURRENT.
2222
2323 The Secret SecretString is expected to be a JSON string with the following format:
2424 {
@@ -154,7 +154,7 @@ def set_secret(service_client, arn, token):
154154 """
155155 current_dict = get_secret_dict (service_client , arn , "AWSCURRENT" )
156156 pending_dict = get_secret_dict (service_client , arn , "AWSPENDING" , token )
157-
157+
158158 # First try to login with the pending secret, if it succeeds, return
159159 conn = get_connection (pending_dict )
160160 if conn :
@@ -202,8 +202,22 @@ def set_secret(service_client, arn, token):
202202 cur .execute ("SHOW GRANTS FOR %s" , current_dict ['username' ])
203203 for row in cur .fetchall ():
204204 grant = row [0 ].split (' TO ' )
205- new_grant_escaped = grant [0 ].replace ('%' ,'%%' ) # % is a special character in Python format strings.
205+ new_grant_escaped = grant [0 ].replace ('%' , '%%' ) # % is a special character in Python format strings.
206206 cur .execute (new_grant_escaped + " TO %s IDENTIFIED BY %s" , (pending_dict ['username' ], pending_dict ['password' ]))
207+
208+ # Copy TLS options to the new user
209+ cur .execute ("SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject FROM mysql.user WHERE User = %s" , current_dict ['username' ])
210+ tls_options = cur .fetchone ()
211+ ssl_type = tls_options [0 ]
212+ if not ssl_type :
213+ cur .execute ("ALTER USER %s@'%%' REQUIRE NONE" , pending_dict ['username' ])
214+ elif ssl_type == "ANY" :
215+ cur .execute ("ALTER USER %s@'%%' REQUIRE SSL" , pending_dict ['username' ])
216+ elif ssl_type == "X509" :
217+ cur .execute ("ALTER USER %s@'%%' REQUIRE X509" , pending_dict ['username' ])
218+ else :
219+ cur .execute ("ALTER USER %s@'%%' REQUIRE CIPHER %s AND ISSUER %s AND SUBJECT %s" , (pending_dict ['username' ], tls_options [1 ], tls_options [2 ], tls_options [3 ]))
220+
207221 conn .commit ()
208222 logger .info ("setSecret: Successfully set password for %s in MariaDB DB for secret arn %s." % (pending_dict ['username' ], arn ))
209223 finally :
@@ -286,8 +300,9 @@ def finish_secret(service_client, arn, token):
286300def get_connection (secret_dict ):
287301 """Gets a connection to MariaDB DB from a secret dictionary
288302
289- This helper function tries to connect to the database grabbing connection info
290- from the secret dictionary. If successful, it returns the connection, else None
303+ This helper function uses connectivity information from the secret dictionary to initiate
304+ connection attempt(s) to the database. Will attempt a fallback, non-SSL connection when
305+ initial connection fails using SSL and fall_back is True.
291306
292307 Args:
293308 secret_dict (dict): The Secret Dictionary
@@ -299,12 +314,88 @@ def get_connection(secret_dict):
299314 KeyError: If the secret json does not contain the expected keys
300315
301316 """
317+ # Parse and validate the secret JSON string
302318 port = int (secret_dict ['port' ]) if 'port' in secret_dict else 3306
303319 dbname = secret_dict ['dbname' ] if 'dbname' in secret_dict else None
304320
321+ # Get SSL connectivity configuration
322+ use_ssl , fall_back = get_ssl_config (secret_dict )
323+
324+ # if an 'ssl' key is not found or does not contain a valid value, attempt an SSL connection and fall back to non-SSL on failure
325+ conn = connect_and_authenticate (secret_dict , port , dbname , use_ssl )
326+ if conn or not fall_back :
327+ return conn
328+ else :
329+ return connect_and_authenticate (secret_dict , port , dbname , False )
330+
331+
332+ def get_ssl_config (secret_dict ):
333+ """Gets the desired SSL and fall back behavior using a secret dictionary
334+
335+ This helper function uses the existance and value the 'ssl' key in a secret dictionary
336+ to determine desired SSL connectivity configuration. Its behavior is as follows:
337+ - 'ssl' key DNE or invalid type/value: return True, True
338+ - 'ssl' key is bool: return secret_dict['ssl'], False
339+ - 'ssl' key equals "true" ignoring case: return True, False
340+ - 'ssl' key equals "false" ignoring case: return False, False
341+
342+ Args:
343+ secret_dict (dict): The Secret Dictionary
344+
345+ Returns:
346+ Tuple(use_ssl, fall_back): SSL configuration
347+ - use_ssl (bool): Flag indicating if an SSL connection should be attempted
348+ - fall_back (bool): Flag indicating if non-SSL connection should be attempted if SSL connection fails
349+
350+ """
351+ # Default to True for SSL and fall_back mode if 'ssl' key DNE
352+ if 'ssl' not in secret_dict :
353+ return True , True
354+
355+ # Handle type bool
356+ if isinstance (secret_dict ['ssl' ], bool ):
357+ return secret_dict ['ssl' ], False
358+
359+ # Handle type string
360+ if isinstance (secret_dict ['ssl' ], str ):
361+ ssl = secret_dict ['ssl' ].lower ()
362+ if ssl == "true" :
363+ return True , False
364+ elif ssl == "false" :
365+ return False , False
366+ else :
367+ # Invalid string value, default to True for both SSL and fall_back mode
368+ return True , True
369+
370+ # Invalid type, default to True for both SSL and fall_back mode
371+ return True , True
372+
373+
374+ def connect_and_authenticate (secret_dict , port , dbname , use_ssl ):
375+ """Attempt to connect and authenticate to a MariaDB instance
376+
377+ This helper function tries to connect to the database using connectivity info passed in.
378+ If successful, it returns the connection, else None
379+
380+ Args:
381+ - secret_dict (dict): The Secret Dictionary
382+ - port (int): The databse port to connect to
383+ - dbname (str): Name of the database
384+ - use_ssl (bool): Flag indicating whether connection should use SSL/TLS
385+
386+ Returns:
387+ Connection: The pymongo.database.Database object if successful. None otherwise
388+
389+ Raises:
390+ KeyError: If the secret json does not contain the expected keys
391+
392+ """
393+ ssl = {'ca' : '/etc/pki/tls/cert.pem' } if use_ssl else None
394+
305395 # Try to obtain a connection to the db
306396 try :
307- conn = pymysql .connect (secret_dict ['host' ], user = secret_dict ['username' ], passwd = secret_dict ['password' ], port = port , db = dbname , connect_timeout = 5 )
397+ # Checks hostname and verifies server certificate implictly when 'ca' key is in 'ssl' dictionary
398+ conn = pymysql .connect (secret_dict ['host' ], user = secret_dict ['username' ], passwd = secret_dict ['password' ], port = port , db = dbname , connect_timeout = 5 , ssl = ssl )
308399 return conn
309400 except pymysql .OperationalError :
310401 return None
@@ -378,6 +469,7 @@ def get_alt_username(current_username):
378469 raise ValueError ("Unable to clone user, username length with _clone appended would exceed 80 characters" )
379470 return new_username
380471
472+
381473def is_rds_replica_database (replica_dict , master_dict ):
382474 """Validates that the database of a secret is a replica of the database of the master secret
383475
0 commit comments