Skip to content

Commit 07fe09a

Browse files
#69 - Fix STS endpoint for China
1 parent 492e9fb commit 07fe09a

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

sagemaker_ssh_helper/aws.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ def get_console_domain(self):
1515
return f"{self.region}.console.amazonaws-us-gov.com"
1616
return f"{self.region}.console.aws.amazon.com"
1717

18+
def get_sts_endpoint(self):
19+
if self.region.startswith("cn-"):
20+
return "https://sts.{}.amazonaws.com.cn".format(self.region)
21+
return "https://sts.{}.amazonaws.com".format(self.region)
22+
1823
@staticmethod
1924
def is_arn(arn: str):
2025
"""

sagemaker_ssh_helper/wrapper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def _augment(self):
7373
def _augment_env(self, env):
7474
if self.local_user_id is None:
7575
region = self.sagemaker_session.boto_region_name
76-
endpoint_url = "https://sts.{}.amazonaws.com".format(region)
76+
aws = AWS(region)
77+
endpoint_url = aws.get_sts_endpoint()
7778
caller_id = boto3.client("sts", region_name=region, endpoint_url=endpoint_url).get_caller_identity()
7879
user_id = caller_id.get('UserId')
7980
else:

0 commit comments

Comments
 (0)