Skip to content

Commit e00c593

Browse files
sfc-gh-pczajkasfc-gh-turbaszek
authored andcommitted
[async] WIF impersonation for AWS (#2517)
1 parent f052550 commit e00c593

File tree

5 files changed

+65
-21
lines changed

5 files changed

+65
-21
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,14 +423,18 @@ async def __open_connection(self):
423423
)
424424
if (
425425
self._workload_identity_impersonation_path
426-
and self._workload_identity_provider != AttestationProvider.GCP
426+
and self._workload_identity_provider
427+
not in (
428+
AttestationProvider.GCP,
429+
AttestationProvider.AWS,
430+
)
427431
):
428432
Error.errorhandler_wrapper(
429433
self,
430434
None,
431435
ProgrammingError,
432436
{
433-
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
437+
"msg": "workload_identity_impersonation_path is currently only supported for GCP and AWS.",
434438
"errno": ER_INVALID_WIF_SETTINGS,
435439
},
436440
)

src/snowflake/connector/aio/_wif_util.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,31 @@ async def get_aws_region() -> str:
4949
return region
5050

5151

52-
async def create_aws_attestation() -> WorkloadIdentityAttestation:
52+
async def get_aws_session(impersonation_path: list[str] | None = None):
53+
"""Creates an aioboto3 session with the appropriate credentials.
54+
55+
If impersonation_path is provided, this uses the role at the end of the path. Otherwise, this uses the role attached to the current workload.
56+
"""
57+
session = aioboto3.Session()
58+
59+
impersonation_path = impersonation_path or []
60+
for arn in impersonation_path:
61+
async with session.client("sts") as sts_client:
62+
response = await sts_client.assume_role(
63+
RoleArn=arn, RoleSessionName="identity-federation-session"
64+
)
65+
creds = response["Credentials"]
66+
session = aioboto3.Session(
67+
aws_access_key_id=creds["AccessKeyId"],
68+
aws_secret_access_key=creds["SecretAccessKey"],
69+
aws_session_token=creds["SessionToken"],
70+
)
71+
return session
72+
73+
74+
async def create_aws_attestation(
75+
impersonation_path: list[str] | None = None,
76+
) -> WorkloadIdentityAttestation:
5377
"""Tries to create a workload identity attestation for AWS.
5478
5579
If the application isn't running on AWS or no credentials were found, raises an error.
@@ -60,7 +84,7 @@ async def create_aws_attestation() -> WorkloadIdentityAttestation:
6084
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
6185
)
6286

63-
session = aioboto3.Session()
87+
session = await get_aws_session(impersonation_path)
6488
aws_creds = await session.get_credentials()
6589
if not aws_creds:
6690
raise ProgrammingError(
@@ -286,7 +310,7 @@ async def create_attestation(
286310
)
287311

288312
if provider == AttestationProvider.AWS:
289-
return await create_aws_attestation()
313+
return await create_aws_attestation(impersonation_path)
290314
elif provider == AttestationProvider.AZURE:
291315
return await create_azure_attestation(entra_resource, session_manager)
292316
elif provider == AttestationProvider.GCP:

test/unit/aio/csp_helpers_async.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ async def async_get_arn():
202202
)
203203

204204
# Mock the async STS client for direct aioboto3 usage
205+
fake_aws_self = self
206+
205207
class MockStsClient:
206208
async def __aenter__(self):
207209
return self
@@ -212,6 +214,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
212214
async def get_caller_identity(self):
213215
return await async_get_caller_identity()
214216

217+
async def assume_role(self, **kwargs):
218+
return fake_aws_self.assume_role(**kwargs)
219+
215220
def mock_session_client(service_name):
216221
if service_name == "sts":
217222
return MockStsClient()

test/unit/aio/test_auth_workload_identity_async.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,22 @@ async def test_explicit_aws_generates_unique_assertion_content(
252252
)
253253

254254

255+
async def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_path(
256+
fake_aws_environment: FakeAwsEnvironmentAsync,
257+
):
258+
impersonation_path = [
259+
"arn:aws:iam::123456789:role/role2",
260+
"arn:aws:iam::123456789:role/role3",
261+
]
262+
fake_aws_environment.assumption_path = impersonation_path
263+
auth_class = AuthByWorkloadIdentity(
264+
provider=AttestationProvider.AWS, impersonation_path=impersonation_path
265+
)
266+
await auth_class.prepare(conn=None)
267+
268+
assert fake_aws_environment.assume_role_call_count == 2
269+
270+
255271
# -- GCP Tests --
256272

257273

test/unit/aio/test_connection_async_unit.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -659,16 +659,14 @@ async def test_workload_identity_provider_is_required_for_wif_authenticator(
659659
"provider_param",
660660
[
661661
# Strongly-typed values.
662-
AttestationProvider.AWS,
663662
AttestationProvider.AZURE,
664663
AttestationProvider.OIDC,
665664
# String values.
666-
"AWS",
667665
"AZURE",
668666
"OIDC",
669667
],
670668
)
671-
async def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
669+
async def test_workload_identity_impersonation_path_errors_for_unsupported_providers(
672670
monkeypatch, provider_param
673671
):
674672
async def mock_authenticate(*_):
@@ -690,20 +688,22 @@ async def mock_authenticate(*_):
690688
],
691689
)
692690
assert (
693-
"workload_identity_impersonation_path is currently only supported for GCP."
691+
"workload_identity_impersonation_path is currently only supported for GCP and AWS."
694692
in str(excinfo.value)
695693
)
696694

697695

698696
@pytest.mark.parametrize(
699-
"provider_param",
697+
"provider_param,impersonation_path",
700698
[
701-
AttestationProvider.GCP,
702-
"GCP",
699+
(AttestationProvider.GCP, ["sa2@project.iam.gserviceaccount.com"]),
700+
(AttestationProvider.AWS, ["arn:aws:iam::1234567890:role/role2"]),
701+
("GCP", ["sa2@project.iam.gserviceaccount.com"]),
702+
("AWS", ["arn:aws:iam::1234567890:role/role2"]),
703703
],
704704
)
705-
async def test_workload_identity_impersonation_path_supported_for_gcp_provider(
706-
monkeypatch, provider_param
705+
async def test_workload_identity_impersonation_path_populates_auth_class_for_supported_provider(
706+
monkeypatch, provider_param, impersonation_path
707707
):
708708
async def mock_authenticate(*_):
709709
pass
@@ -718,14 +718,9 @@ async def mock_authenticate(*_):
718718
account="account",
719719
authenticator="WORKLOAD_IDENTITY",
720720
workload_identity_provider=provider_param,
721-
workload_identity_impersonation_path=[
722-
"sa2@project.iam.gserviceaccount.com"
723-
],
721+
workload_identity_impersonation_path=impersonation_path,
724722
)
725-
assert conn.auth_class.provider == AttestationProvider.GCP
726-
assert conn.auth_class.impersonation_path == [
727-
"sa2@project.iam.gserviceaccount.com"
728-
]
723+
assert conn.auth_class.impersonation_path == impersonation_path
729724

730725

731726
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)