Skip to content

Commit d14d666

Browse files
authored
Merge pull request #15 from willtong1234/develop
Support DocumentDB and MySQL 8
2 parents 94722e6 + 13e6deb commit d14d666

File tree

4 files changed

+65
-6
lines changed

4 files changed

+65
-6
lines changed

SecretsManagerMongoDBRotationMultiUser/lambda_function.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def lambda_handler(event, context):
2929
'dbname': <optional: database name>,
3030
'port': <optional: if not specified, default port 27017 will be used>,
3131
'masterarn': <required: the arn of the master secret which will be used to create users/change passwords>
32+
'ssl': <optional: if not specified, defaults to false. This must be true if being used for DocumentDB rotations where the cluster has TLS enabled>
3233
}
3334
3435
Args:
@@ -282,10 +283,13 @@ def get_connection(secret_dict):
282283
"""
283284
port = int(secret_dict['port']) if 'port' in secret_dict else 27017
284285
dbname = secret_dict['dbname'] if 'dbname' in secret_dict else "admin"
286+
ssl = False
287+
if 'ssl' in secret_dict:
288+
ssl = (secret_dict['ssl'].lower() == "true") if type(secret_dict['ssl']) is str else bool(secret_dict['ssl'])
285289

286290
# Try to obtain a connection to the db
287291
try:
288-
client = MongoClient(host=secret_dict['host'], port=port, connectTimeoutMS=5000, serverSelectionTimeoutMS=5000)
292+
client = MongoClient(host=secret_dict['host'], port=port, connectTimeoutMS=5000, serverSelectionTimeoutMS=5000, ssl=ssl)
289293
db = client[dbname]
290294
db.authenticate(secret_dict['username'], secret_dict['password'])
291295
return db

SecretsManagerMongoDBRotationSingleUser/lambda_function.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def lambda_handler(event, context):
2626
'password': <required: password>,
2727
'dbname': <optional: database name>,
2828
'port': <optional: if not specified, default port 27017 will be used>
29+
'ssl': <optional: if not specified, defaults to false. This must be true if being used for DocumentDB rotations where the cluster has TLS enabled>
2930
}
3031
3132
Args:
@@ -261,10 +262,13 @@ def get_connection(secret_dict):
261262
# Parse and validate the secret JSON string
262263
port = int(secret_dict['port']) if 'port' in secret_dict else 27017
263264
dbname = secret_dict['dbname'] if 'dbname' in secret_dict else "admin"
265+
ssl = False
266+
if 'ssl' in secret_dict:
267+
ssl = (secret_dict['ssl'].lower() == "true") if type(secret_dict['ssl']) is str else bool(secret_dict['ssl'])
264268

265269
# Try to obtain a connection to the db
266270
try:
267-
client = MongoClient(host=secret_dict['host'], port=port, connectTimeoutMS=5000, serverSelectionTimeoutMS=5000)
271+
client = MongoClient(host=secret_dict['host'], port=port, connectTimeoutMS=5000, serverSelectionTimeoutMS=5000, ssl=ssl)
268272
db = client[dbname]
269273
db.authenticate(secret_dict['username'], secret_dict['password'])
270274
return db

SecretsManagerRDSMySQLRotationMultiUser/lambda_function.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,24 @@ def set_secret(service_client, arn, token):
182182
# Now set the password to the pending password
183183
try:
184184
with conn.cursor() as cur:
185-
# List the grants on the current user and add them to the pending user.
186-
# This also creates the user if it does not already exist
185+
cur.execute("SELECT User FROM mysql.user WHERE User = %s", pending_dict['username'])
186+
# Create the user if it does not exist
187+
if cur.rowcount == 0:
188+
cur.execute("CREATE USER %s IDENTIFIED BY %s", (pending_dict['username'], pending_dict['password']))
189+
190+
# Copy grants to the new user
187191
cur.execute("SHOW GRANTS FOR %s", current_dict['username'])
188192
for row in cur.fetchall():
189193
grant = row[0].split(' TO ')
190194
new_grant = "%s TO '%s'" % (grant[0], pending_dict['username'])
191195
new_grant_escaped = new_grant.replace('%','%%') # % is a special character in Python format strings.
192-
cur.execute(new_grant_escaped + " IDENTIFIED BY %s", pending_dict['password'])
196+
cur.execute(new_grant_escaped)
197+
198+
# Set the password for the user and commit
199+
cur.execute("SELECT VERSION()")
200+
ver = cur.fetchone()
201+
password_option = get_password_option(ver[0])
202+
cur.execute("SET PASSWORD FOR %s = " + password_option, (pending_dict['username'], pending_dict['password']))
193203
conn.commit()
194204
logger.info("setSecret: Successfully set password for %s in MySQL DB for secret arn %s." % (pending_dict['username'], arn))
195205
finally:
@@ -363,3 +373,22 @@ def get_alt_username(current_username):
363373
if len(new_username) > 16:
364374
raise ValueError("Unable to clone user, username length with _clone appended would exceed 16 characters")
365375
return new_username
376+
377+
378+
def get_password_option(version):
379+
"""Gets the password option template string to use for the SET PASSWORD sql query
380+
381+
This helper function takes in the mysql version and returns the appropriate password option template string that can
382+
be used in the SET PASSWORD query for that mysql version.
383+
384+
Args:
385+
version (string): The mysql database version
386+
387+
Returns:
388+
PasswordOption: The password option string
389+
390+
"""
391+
if version.startswith("8"):
392+
return "%s"
393+
else:
394+
return "PASSWORD(%s)"

SecretsManagerRDSMySQLRotationSingleUser/lambda_function.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,10 @@ def set_secret(service_client, arn, token):
168168
# Now set the password to the pending password
169169
try:
170170
with conn.cursor() as cur:
171-
cur.execute("SET PASSWORD = PASSWORD(%s)", pending_dict['password'])
171+
cur.execute("SELECT VERSION()")
172+
ver = cur.fetchone()
173+
password_option = get_password_option(ver[0])
174+
cur.execute("SET PASSWORD = " + password_option, pending_dict['password'])
172175
conn.commit()
173176
logger.info("setSecret: Successfully set password for user %s in MySQL DB for secret arn %s." % (pending_dict['username'], arn))
174177
finally:
@@ -315,3 +318,22 @@ def get_secret_dict(service_client, arn, stage, token=None):
315318

316319
# Parse and return the secret JSON string
317320
return secret_dict
321+
322+
323+
def get_password_option(version):
324+
"""Gets the password option template string to use for the SET PASSWORD sql query
325+
326+
This helper function takes in the mysql version and returns the appropriate password option template string that can
327+
be used in the SET PASSWORD query for that mysql version.
328+
329+
Args:
330+
version (string): The mysql database version
331+
332+
Returns:
333+
PasswordOption: The password option string
334+
335+
"""
336+
if version.startswith("8"):
337+
return "%s"
338+
else:
339+
return "PASSWORD(%s)"

0 commit comments

Comments
 (0)