1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
1414
15+ import contextlib
16+ import json
17+
1518from botocore import exceptions
1619
17- KEY_ALIAS = "SageMakerIntegTestKmsKey"
20+ PRINCIPAL_TEMPLATE = '["{account_id}", "{role_arn}", ' \
21+ '"arn:aws:iam::{account_id}:role/{sagemaker_role}"] '
22+
23+ KEY_ALIAS = 'SageMakerTestKMSKey'
24+ KMS_S3_ALIAS = 'SageMakerTestS3KMSKey'
25+ POLICY_NAME = 'default'
1826KEY_POLICY = '''
1927{{
2028 "Version": "2012-10-17",
21- "Id": "sagemaker-kms-integ-test-policy ",
29+ "Id": "{id} ",
2230 "Statement": [
2331 {{
2432 "Sid": "Enable IAM User Permissions",
2533 "Effect": "Allow",
2634 "Principal": {{
27- "AWS": "*"
35+ "AWS": {principal}
2836 }},
2937 "Action": "kms:*",
3038 "Resource": "*"
@@ -42,22 +50,75 @@ def _get_kms_key_arn(kms_client, alias):
4250 return None
4351
4452
45- def _create_kms_key (kms_client , account_id ):
53+ def _get_kms_key_id (kms_client , alias ):
54+ try :
55+ response = kms_client .describe_key (KeyId = 'alias/' + alias )
56+ return response ['KeyMetadata' ]['KeyId' ]
57+ except kms_client .exceptions .NotFoundException :
58+ return None
59+
60+
61+ def _create_kms_key (kms_client ,
62+ account_id ,
63+ role_arn = None ,
64+ sagemaker_role = 'SageMakerRole' ,
65+ alias = KEY_ALIAS ):
66+ if role_arn :
67+ principal = PRINCIPAL_TEMPLATE .format (account_id = account_id ,
68+ role_arn = role_arn ,
69+ sagemaker_role = sagemaker_role )
70+ else :
71+ principal = "{account_id}" .format (account_id = account_id )
72+
4673 response = kms_client .create_key (
47- Policy = KEY_POLICY .format (account_id = account_id ),
74+ Policy = KEY_POLICY .format (id = POLICY_NAME , principal = principal , sagemaker_role = sagemaker_role ),
4875 Description = 'KMS key for SageMaker Python SDK integ tests' ,
4976 )
5077 key_arn = response ['KeyMetadata' ]['Arn' ]
51- response = kms_client .create_alias (AliasName = 'alias/' + KEY_ALIAS , TargetKeyId = key_arn )
78+
79+ if alias :
80+ kms_client .create_alias (AliasName = 'alias/' + alias , TargetKeyId = key_arn )
5281 return key_arn
5382
5483
55- def get_or_create_kms_key (kms_client , account_id ):
56- kms_key_arn = _get_kms_key_arn (kms_client , KEY_ALIAS )
57- if kms_key_arn is not None :
58- return kms_key_arn
59- else :
60- return _create_kms_key (kms_client , account_id )
84+ def _add_role_to_policy (kms_client ,
85+ account_id ,
86+ role_arn ,
87+ alias = KEY_ALIAS ,
88+ sagemaker_role = 'SageMakerRole' ):
89+ key_id = _get_kms_key_id (kms_client , alias )
90+ policy = kms_client .get_key_policy (KeyId = key_id , PolicyName = POLICY_NAME )
91+ policy = json .loads (policy ['Policy' ])
92+ principal = policy ['Statement' ][0 ]['Principal' ]['AWS' ]
93+
94+ if role_arn not in principal or sagemaker_role not in principal :
95+ principal = PRINCIPAL_TEMPLATE .format (account_id = account_id ,
96+ role_arn = role_arn ,
97+ sagemaker_role = sagemaker_role )
98+
99+ kms_client .put_key_policy (KeyId = key_id ,
100+ PolicyName = POLICY_NAME ,
101+ Policy = KEY_POLICY .format (id = POLICY_NAME , principal = principal ))
102+
103+
104+ def get_or_create_kms_key (kms_client ,
105+ account_id ,
106+ role_arn = None ,
107+ alias = KEY_ALIAS ,
108+ sagemaker_role = 'SageMakerRole' ):
109+ kms_key_arn = _get_kms_key_arn (kms_client , alias )
110+
111+ if kms_key_arn is None :
112+ return _create_kms_key (kms_client , account_id , role_arn , sagemaker_role , alias )
113+
114+ if role_arn :
115+ _add_role_to_policy (kms_client ,
116+ account_id ,
117+ role_arn ,
118+ alias ,
119+ sagemaker_role )
120+
121+ return kms_key_arn
61122
62123
63124KMS_BUCKET_POLICY = """{
@@ -92,9 +153,13 @@ def get_or_create_kms_key(kms_client, account_id):
92153}"""
93154
94155
95- def get_or_create_bucket_with_encryption (boto_session ):
156+ @contextlib .contextmanager
157+ def bucket_with_encryption (boto_session , sagemaker_role ):
96158 account = boto_session .client ('sts' ).get_caller_identity ()['Account' ]
97- kms_key_arn = get_or_create_kms_key (boto_session .client ('kms' ), account )
159+ role_arn = boto_session .client ('sts' ).get_caller_identity ()['Arn' ]
160+
161+ kms_client = boto_session .client ('kms' )
162+ kms_key_arn = _create_kms_key (kms_client , account , role_arn , sagemaker_role , None )
98163
99164 region = boto_session .region_name
100165 bucket_name = 'sagemaker-{}-{}-with-kms' .format (region , account )
@@ -132,4 +197,6 @@ def get_or_create_bucket_with_encryption(boto_session):
132197 Policy = KMS_BUCKET_POLICY % (bucket_name , bucket_name )
133198 )
134199
135- return 's3://' + bucket_name , kms_key_arn
200+ yield 's3://' + bucket_name , kms_key_arn
201+
202+ kms_client .schedule_key_deletion (KeyId = kms_key_arn , PendingWindowInDays = 7 )
0 commit comments