Skip to content

Commit f052550

Browse files
sfc-gh-eqinsfc-gh-pmansour
authored andcommitted
Support WIF Impersonation on AWS workloads (#2517)
Co-authored-by: Peter Mansour <peter.mansour@snowflake.com>
1 parent a5ea494 commit f052550

File tree

5 files changed

+87
-23
lines changed

5 files changed

+87
-23
lines changed

src/snowflake/connector/connection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,14 +1377,18 @@ def __open_connection(self):
13771377
)
13781378
if (
13791379
self._workload_identity_impersonation_path
1380-
and self._workload_identity_provider != AttestationProvider.GCP
1380+
and self._workload_identity_provider
1381+
not in (
1382+
AttestationProvider.GCP,
1383+
AttestationProvider.AWS,
1384+
)
13811385
):
13821386
Error.errorhandler_wrapper(
13831387
self,
13841388
None,
13851389
ProgrammingError,
13861390
{
1387-
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
1391+
"msg": "workload_identity_impersonation_path is currently only supported for GCP and AWS.",
13881392
"errno": ER_INVALID_WIF_SETTINGS,
13891393
},
13901394
)

src/snowflake/connector/wif_util.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,29 @@ def get_aws_sts_hostname(region: str, partition: str) -> str:
148148
)
149149

150150

151+
def get_aws_session(impersonation_path: list[str] | None = None):
152+
"""Creates a boto3 session with the appropriate credentials.
153+
154+
If impersonation_path is provided, this uses the role at the end of the path. Otherwise, this uses the role attached to the current workload.
155+
"""
156+
session = boto3.session.Session()
157+
158+
impersonation_path = impersonation_path or []
159+
for arn in impersonation_path:
160+
response = session.client("sts").assume_role(
161+
RoleArn=arn, RoleSessionName="identity-federation-session"
162+
)
163+
creds = response["Credentials"]
164+
session = boto3.session.Session(
165+
aws_access_key_id=creds["AccessKeyId"],
166+
aws_secret_access_key=creds["SecretAccessKey"],
167+
aws_session_token=creds["SessionToken"],
168+
)
169+
return session
170+
171+
151172
def create_aws_attestation(
152-
session_manager: SessionManager | None = None,
173+
impersonation_path: list[str] | None = None,
153174
) -> WorkloadIdentityAttestation:
154175
"""Tries to create a workload identity attestation for AWS.
155176
@@ -162,7 +183,8 @@ def create_aws_attestation(
162183
)
163184

164185
# TODO: SNOW-2223669 Investigate if our adapters - containing settings of http traffic - should be passed here as boto urllib3session. Those requests go to local servers, so they do not need Proxy setup or Headers customization in theory. But we may want to have all the traffic going through one class (e.g. Adapter or mixin).
165-
session = boto3.session.Session()
186+
session = get_aws_session(impersonation_path)
187+
166188
aws_creds = session.get_credentials()
167189
if not aws_creds:
168190
raise ProgrammingError(
@@ -396,7 +418,7 @@ def create_attestation(
396418
)
397419

398420
if provider == AttestationProvider.AWS:
399-
return create_aws_attestation(session_manager)
421+
return create_aws_attestation(impersonation_path)
400422
elif provider == AttestationProvider.AZURE:
401423
return create_azure_attestation(entra_resource, session_manager)
402424
elif provider == AttestationProvider.GCP:

test/csp_helpers.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ def gen_dummy_id_token(
4040
)
4141

4242

43-
def gen_dummy_access_token(sub="test-subject") -> str:
43+
def gen_dummy_access_token(sub="test-subject", key="secret") -> str:
4444
"""Generates a dummy access token using the given subject."""
45-
key = "secret"
4645
logger.debug(f"Generating dummy access token for subject {sub}")
4746
return (sub + key).encode("utf-8").hex()
4847

@@ -368,6 +367,11 @@ class FakeAwsEnvironment:
368367
def __init__(self):
369368
# Defaults used for generating a token. Can be overriden in individual tests.
370369
self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab"
370+
# Path of roles that can be assumed. Empty if no impersonation is allowed.
371+
# Can be overriden in individual tests.
372+
self.assumption_path = []
373+
self.assume_role_call_count = 0
374+
371375
self.caller_identity = {"Arn": self.arn}
372376
self.region = "us-east-1"
373377
self.credentials = Credentials(access_key="ak", secret_key="sk")
@@ -376,6 +380,25 @@ def __init__(self):
376380
)
377381
self.metadata_token = "test-token"
378382

383+
def assume_role(self, **kwargs):
384+
if (
385+
self.assumption_path
386+
and kwargs["RoleArn"] == self.assumption_path[self.assume_role_call_count]
387+
):
388+
arn = self.assumption_path[self.assume_role_call_count]
389+
self.assume_role_call_count += 1
390+
return {
391+
"Credentials": {
392+
"AccessKeyId": "access_key",
393+
"SecretAccessKey": "secret_key",
394+
"SessionToken": "session_token",
395+
"Expiration": int(time()) + 60 * 60,
396+
},
397+
"AssumedRoleUser": {"AssumedRoleId": hash(arn), "Arn": arn},
398+
"ResponseMetadata": {},
399+
}
400+
return {}
401+
379402
def get_region(self):
380403
return self.region
381404

@@ -401,6 +424,7 @@ def fetcher_fetch_metadata_token(self):
401424
def boto3_client(self, *args, **kwargs):
402425
mock_client = mock.Mock()
403426
mock_client.get_caller_identity.return_value = self.caller_identity
427+
mock_client.assume_role = self.assume_role
404428
return mock_client
405429

406430
def __enter__(self):
@@ -443,6 +467,9 @@ def __enter__(self):
443467
side_effect=self.boto3_client,
444468
)
445469
)
470+
self.patchers.append(
471+
mock.patch("boto3.session.Session.client", side_effect=self.boto3_client)
472+
)
446473
for patcher in self.patchers:
447474
patcher.__enter__()
448475
return self

test/unit/test_auth_workload_identity.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,22 @@ def test_get_aws_sts_hostname_invalid_inputs(region, partition):
274274
assert "Invalid AWS partition" in str(excinfo.value)
275275

276276

277+
def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_path(
278+
fake_aws_environment: FakeAwsEnvironment,
279+
):
280+
impersonation_path = [
281+
"arn:aws:iam::123456789:role/role2",
282+
"arn:aws:iam::123456789:role/role3",
283+
]
284+
fake_aws_environment.assumption_path = impersonation_path
285+
auth_class = AuthByWorkloadIdentity(
286+
provider=AttestationProvider.AWS, impersonation_path=impersonation_path
287+
)
288+
auth_class.prepare(conn=None)
289+
290+
assert fake_aws_environment.assume_role_call_count == 2
291+
292+
277293
# -- GCP Tests --
278294

279295

test/unit/test_connection.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -684,16 +684,14 @@ def test_workload_identity_provider_is_required_for_wif_authenticator(
684684
"provider_param",
685685
[
686686
# Strongly-typed values.
687-
AttestationProvider.AWS,
688687
AttestationProvider.AZURE,
689688
AttestationProvider.OIDC,
690689
# String values.
691-
"AWS",
692690
"AZURE",
693691
"OIDC",
694692
],
695693
)
696-
def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
694+
def test_workload_identity_impersonation_path_errors_for_unsupported_providers(
697695
monkeypatch, provider_param
698696
):
699697
with monkeypatch.context() as m:
@@ -711,20 +709,22 @@ def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
711709
],
712710
)
713711
assert (
714-
"workload_identity_impersonation_path is currently only supported for GCP."
712+
"workload_identity_impersonation_path is currently only supported for GCP and AWS."
715713
in str(excinfo.value)
716714
)
717715

718716

719717
@pytest.mark.parametrize(
720-
"provider_param",
718+
"provider_param,impersonation_path",
721719
[
722-
AttestationProvider.GCP,
723-
"GCP",
720+
(AttestationProvider.GCP, ["sa2@project.iam.gserviceaccount.com"]),
721+
(AttestationProvider.AWS, ["arn:aws:iam::1234567890:role/role2"]),
722+
("GCP", ["sa2@project.iam.gserviceaccount.com"]),
723+
("AWS", ["arn:aws:iam::1234567890:role/role2"]),
724724
],
725725
)
726-
def test_workload_identity_impersonation_path_supported_for_gcp_provider(
727-
monkeypatch, provider_param
726+
def test_workload_identity_impersonation_path_populates_auth_class_for_supported_provider(
727+
monkeypatch, provider_param, impersonation_path
728728
):
729729
with monkeypatch.context() as m:
730730
m.setattr(
@@ -735,14 +735,9 @@ def test_workload_identity_impersonation_path_supported_for_gcp_provider(
735735
account="account",
736736
authenticator="WORKLOAD_IDENTITY",
737737
workload_identity_provider=provider_param,
738-
workload_identity_impersonation_path=[
739-
"sa2@project.iam.gserviceaccount.com"
740-
],
738+
workload_identity_impersonation_path=impersonation_path,
741739
)
742-
assert conn.auth_class.provider == AttestationProvider.GCP
743-
assert conn.auth_class.impersonation_path == [
744-
"sa2@project.iam.gserviceaccount.com"
745-
]
740+
assert conn.auth_class.impersonation_path == impersonation_path
746741

747742

748743
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)