|
17 | 17 |
|
18 | 18 | from botocore import exceptions |
19 | 19 |
|
| 20 | +from sagemaker import utils |
| 21 | + |
20 | 22 | PRINCIPAL_TEMPLATE = ( |
21 | 23 | '["{account_id}", "{role_arn}", ' '"arn:aws:iam::{account_id}:role/{sagemaker_role}"] ' |
22 | 24 | ) |
@@ -108,7 +110,10 @@ def get_or_create_kms_key( |
108 | 110 | kms_client = sagemaker_session.boto_session.client("kms") |
109 | 111 | kms_key_arn = _get_kms_key_arn(kms_client, alias) |
110 | 112 |
|
111 | | - sts_client = sagemaker_session.boto_session.client("sts") |
| 113 | + region = sagemaker_session.boto_region_name |
| 114 | + sts_client = sagemaker_session.boto_session.client( |
| 115 | + "sts", region_name=region, endpoint_url=utils.sts_regional_endpoint(region) |
| 116 | + ) |
112 | 117 | account_id = sts_client.get_caller_identity()["Account"] |
113 | 118 |
|
114 | 119 | if kms_key_arn is None: |
@@ -154,8 +159,13 @@ def get_or_create_kms_key( |
154 | 159 |
|
155 | 160 | @contextlib.contextmanager |
156 | 161 | def bucket_with_encryption(boto_session, sagemaker_role): |
157 | | - account = boto_session.client("sts").get_caller_identity()["Account"] |
158 | | - role_arn = boto_session.client("sts").get_caller_identity()["Arn"] |
| 162 | + region = boto_session.region_name |
| 163 | + sts_client = boto_session.client( |
| 164 | + "sts", region_name=region, endpoint_url=utils.sts_regional_endpoint(region) |
| 165 | + ) |
| 166 | + |
| 167 | + account = sts_client.get_caller_identity()["Account"] |
| 168 | + role_arn = sts_client.get_caller_identity()["Arn"] |
159 | 169 |
|
160 | 170 | kms_client = boto_session.client("kms") |
161 | 171 | kms_key_arn = _create_kms_key(kms_client, account, role_arn, sagemaker_role, None) |
|
0 commit comments