Skip to content

Commit a5ea494

Browse files
sfc-gh-pczajkasfc-gh-turbaszek
authored andcommitted
[async] WIF impersonation for GCP #2496
1 parent 6324c39 commit a5ea494

File tree

6 files changed

+230
-11
lines changed

6 files changed

+230
-11
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,10 +421,24 @@ async def __open_connection(self):
421421
"errno": ER_INVALID_WIF_SETTINGS,
422422
},
423423
)
424+
if (
425+
self._workload_identity_impersonation_path
426+
and self._workload_identity_provider != AttestationProvider.GCP
427+
):
428+
Error.errorhandler_wrapper(
429+
self,
430+
None,
431+
ProgrammingError,
432+
{
433+
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
434+
"errno": ER_INVALID_WIF_SETTINGS,
435+
},
436+
)
424437
self.auth_class = AuthByWorkloadIdentity(
425438
provider=self._workload_identity_provider,
426439
token=self._token,
427440
entra_resource=self._workload_identity_entra_resource,
441+
impersonation_path=self._workload_identity_impersonation_path,
428442
)
429443
else:
430444
# okta URL, e.g., https://<account>.okta.com/

src/snowflake/connector/aio/_wif_util.py

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727

2828
logger = logging.getLogger(__name__)
2929

30+
GCP_METADATA_SERVICE_ACCOUNT_BASE_URL = (
31+
"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default"
32+
)
33+
3034

3135
async def get_aws_region() -> str:
3236
"""Get the current AWS workload's region."""
@@ -91,30 +95,108 @@ async def create_aws_attestation() -> WorkloadIdentityAttestation:
9195
)
9296

9397

94-
async def create_gcp_attestation(
95-
session_manager: SessionManager | None = None,
96-
) -> WorkloadIdentityAttestation:
97-
"""Tries to create a workload identity attestation for GCP.
98+
async def get_gcp_access_token(session_manager: SessionManager) -> str:
99+
"""Gets a GCP access token from the metadata server.
100+
101+
If the application isn't running on GCP or no credentials were found, raises an error.
102+
"""
103+
try:
104+
res = await session_manager.request(
105+
method="GET",
106+
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/token",
107+
headers={
108+
"Metadata-Flavor": "Google",
109+
},
110+
)
111+
112+
content = await res.content.read()
113+
response_text = content.decode("utf-8")
114+
return json.loads(response_text)["access_token"]
115+
except Exception as e:
116+
raise ProgrammingError(
117+
msg=f"Error fetching GCP access token: {e}. Ensure the application is running on GCP.",
118+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
119+
)
120+
121+
122+
async def get_gcp_identity_token_via_impersonation(
123+
impersonation_path: list[str], session_manager: SessionManager
124+
) -> str:
125+
"""Gets a GCP identity token from the metadata server.
126+
127+
If the application isn't running on GCP or no credentials were found, raises an error.
128+
"""
129+
if not impersonation_path:
130+
raise ProgrammingError(
131+
msg="Error: impersonation_path cannot be empty.",
132+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
133+
)
134+
135+
current_sa_token = await get_gcp_access_token(session_manager)
136+
impersonation_path = [
137+
f"projects/-/serviceAccounts/{client_id}" for client_id in impersonation_path
138+
]
139+
try:
140+
res = await session_manager.post(
141+
url=f"https://iamcredentials.googleapis.com/v1/{impersonation_path[-1]}:generateIdToken",
142+
headers={
143+
"Authorization": f"Bearer {current_sa_token}",
144+
"Content-Type": "application/json",
145+
},
146+
json={
147+
"delegates": impersonation_path[:-1],
148+
"audience": SNOWFLAKE_AUDIENCE,
149+
},
150+
)
151+
152+
content = await res.content.read()
153+
response_text = content.decode("utf-8")
154+
return json.loads(response_text)["token"]
155+
except Exception as e:
156+
raise ProgrammingError(
157+
msg=f"Error fetching GCP identity token for impersonated GCP service account '{impersonation_path[-1]}': {e}. Ensure the application is running on GCP.",
158+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
159+
)
160+
161+
162+
async def get_gcp_identity_token(session_manager: SessionManager) -> str:
163+
"""Gets a GCP identity token from the metadata server.
98164
99165
If the application isn't running on GCP or no credentials were found, raises an error.
100166
"""
101167
try:
102168
res = await session_manager.request(
103169
method="GET",
104-
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
170+
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/identity?audience={SNOWFLAKE_AUDIENCE}",
105171
headers={
106172
"Metadata-Flavor": "Google",
107173
},
108174
)
109175

110176
content = await res.content.read()
111-
jwt_str = content.decode("utf-8")
177+
return content.decode("utf-8")
112178
except Exception as e:
113179
raise ProgrammingError(
114-
msg=f"Error fetching GCP metadata: {e}. Ensure the application is running on GCP.",
180+
msg=f"Error fetching GCP identity token: {e}. Ensure the application is running on GCP.",
115181
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
116182
)
117183

184+
185+
async def create_gcp_attestation(
186+
session_manager: SessionManager,
187+
impersonation_path: list[str] | None = None,
188+
) -> WorkloadIdentityAttestation:
189+
"""Tries to create a workload identity attestation for GCP.
190+
191+
If the application isn't running on GCP or no credentials were found, raises an error.
192+
"""
193+
if impersonation_path:
194+
jwt_str = await get_gcp_identity_token_via_impersonation(
195+
impersonation_path, session_manager
196+
)
197+
else:
198+
jwt_str = await get_gcp_identity_token(session_manager)
199+
118200
_, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
119201
return WorkloadIdentityAttestation(
120202
AttestationProvider.GCP, jwt_str, {"sub": subject}
@@ -189,6 +271,7 @@ async def create_attestation(
189271
provider: AttestationProvider | None,
190272
entra_resource: str | None = None,
191273
token: str | None = None,
274+
impersonation_path: list[str] | None = None,
192275
session_manager: SessionManager | None = None,
193276
) -> WorkloadIdentityAttestation:
194277
"""Entry point to create an attestation using the given provider.
@@ -207,7 +290,7 @@ async def create_attestation(
207290
elif provider == AttestationProvider.AZURE:
208291
return await create_azure_attestation(entra_resource, session_manager)
209292
elif provider == AttestationProvider.GCP:
210-
return await create_gcp_attestation(session_manager)
293+
return await create_gcp_attestation(session_manager, impersonation_path)
211294
elif provider == AttestationProvider.OIDC:
212295
return create_oidc_attestation(token)
213296
else:

src/snowflake/connector/aio/auth/_workload_identity.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(
2222
provider: AttestationProvider,
2323
token: str | None = None,
2424
entra_resource: str | None = None,
25+
impersonation_path: list[str] | None = None,
2526
**kwargs,
2627
) -> None:
2728
"""Initializes an instance with workload identity authentication."""
@@ -30,6 +31,7 @@ def __init__(
3031
provider=provider,
3132
token=token,
3233
entra_resource=entra_resource,
34+
impersonation_path=impersonation_path,
3335
**kwargs,
3436
)
3537

@@ -44,6 +46,7 @@ async def prepare(
4446
self.provider,
4547
self.entra_resource,
4648
self.token,
49+
self.impersonation_path,
4750
session_manager=(
4851
conn._session_manager.clone(max_retries=0) if conn else None
4952
),

test/unit/aio/test_auth_workload_identity_async.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
from base64 import b64decode
99
from unittest import mock
10+
from unittest.mock import AsyncMock
1011
from urllib.parse import parse_qs, urlparse
1112

1213
import aiohttp
@@ -17,7 +18,7 @@
1718
from snowflake.connector.aio.auth import AuthByWorkloadIdentity
1819
from snowflake.connector.errors import ProgrammingError
1920

20-
from ...csp_helpers import gen_dummy_id_token
21+
from ...csp_helpers import gen_dummy_access_token, gen_dummy_id_token
2122
from .csp_helpers_async import FakeAwsEnvironmentAsync, FakeGceMetadataServiceAsync
2223

2324
logger = logging.getLogger(__name__)
@@ -278,7 +279,7 @@ async def test_explicit_gcp_metadata_server_error_bubbles_up(exception):
278279
with pytest.raises(ProgrammingError) as excinfo:
279280
await auth_class.prepare(conn=None)
280281

281-
assert "Error fetching GCP metadata:" in str(excinfo.value)
282+
assert "Error fetching GCP identity token:" in str(excinfo.value)
282283
assert "Ensure the application is running on GCP." in str(excinfo.value)
283284

284285

@@ -306,6 +307,51 @@ async def test_explicit_gcp_generates_unique_assertion_content(
306307
assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}'
307308

308309

310+
@mock.patch("snowflake.connector.aio._session_manager.SessionManager.post")
311+
async def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa(
312+
mock_post_request, fake_gce_metadata_service: FakeGceMetadataServiceAsync
313+
):
314+
fake_gce_metadata_service.sub = "sa1"
315+
impersonation_path = ["sa2", "sa3"]
316+
sa1_access_token = gen_dummy_access_token("sa1")
317+
sa3_id_token = gen_dummy_id_token("sa3")
318+
319+
# Mock the POST request response
320+
class AsyncResponse:
321+
def __init__(self, content):
322+
self._content = content
323+
self.content = mock.Mock()
324+
self.content.read = AsyncMock(return_value=content)
325+
326+
mock_post_request.return_value = AsyncResponse(
327+
json.dumps({"token": sa3_id_token}).encode("utf-8")
328+
)
329+
330+
auth_class = AuthByWorkloadIdentity(
331+
provider=AttestationProvider.GCP, impersonation_path=impersonation_path
332+
)
333+
await auth_class.prepare(conn=None)
334+
335+
mock_post_request.assert_called_once_with(
336+
url="https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa3:generateIdToken",
337+
headers={
338+
"Authorization": f"Bearer {sa1_access_token}",
339+
"Content-Type": "application/json",
340+
},
341+
json={
342+
"delegates": ["projects/-/serviceAccounts/sa2"],
343+
"audience": "snowflakecomputing.com",
344+
},
345+
)
346+
347+
assert auth_class.assertion_content == '{"_provider":"GCP","sub":"sa3"}'
348+
assert await extract_api_data(auth_class) == {
349+
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
350+
"PROVIDER": "GCP",
351+
"TOKEN": sa3_id_token,
352+
}
353+
354+
309355
# -- Azure Tests --
310356

311357

test/unit/aio/test_connection_async_unit.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,7 @@ async def test_otel_error_message_async(caplog, mock_post_requests):
605605
"workload_identity_entra_resource",
606606
"api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b",
607607
),
608+
("workload_identity_impersonation_path", ["subject-b", "subject-c"]),
608609
],
609610
)
610611
async def test_cannot_set_dependent_params_without_wlid_authenticator(
@@ -654,6 +655,79 @@ async def test_workload_identity_provider_is_required_for_wif_authenticator(
654655
assert expected_error_msg in str(excinfo.value)
655656

656657

658+
@pytest.mark.parametrize(
659+
"provider_param",
660+
[
661+
# Strongly-typed values.
662+
AttestationProvider.AWS,
663+
AttestationProvider.AZURE,
664+
AttestationProvider.OIDC,
665+
# String values.
666+
"AWS",
667+
"AZURE",
668+
"OIDC",
669+
],
670+
)
671+
async def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
672+
monkeypatch, provider_param
673+
):
674+
async def mock_authenticate(*_):
675+
pass
676+
677+
with monkeypatch.context() as m:
678+
m.setattr(
679+
"snowflake.connector.aio._connection.SnowflakeConnection._authenticate",
680+
mock_authenticate,
681+
)
682+
683+
with pytest.raises(ProgrammingError) as excinfo:
684+
await snowflake.connector.aio.connect(
685+
account="account",
686+
authenticator="WORKLOAD_IDENTITY",
687+
workload_identity_provider=provider_param,
688+
workload_identity_impersonation_path=[
689+
"sa2@project.iam.gserviceaccount.com"
690+
],
691+
)
692+
assert (
693+
"workload_identity_impersonation_path is currently only supported for GCP."
694+
in str(excinfo.value)
695+
)
696+
697+
698+
@pytest.mark.parametrize(
699+
"provider_param",
700+
[
701+
AttestationProvider.GCP,
702+
"GCP",
703+
],
704+
)
705+
async def test_workload_identity_impersonation_path_supported_for_gcp_provider(
706+
monkeypatch, provider_param
707+
):
708+
async def mock_authenticate(*_):
709+
pass
710+
711+
with monkeypatch.context() as m:
712+
m.setattr(
713+
"snowflake.connector.aio._connection.SnowflakeConnection._authenticate",
714+
mock_authenticate,
715+
)
716+
717+
conn = await snowflake.connector.aio.connect(
718+
account="account",
719+
authenticator="WORKLOAD_IDENTITY",
720+
workload_identity_provider=provider_param,
721+
workload_identity_impersonation_path=[
722+
"sa2@project.iam.gserviceaccount.com"
723+
],
724+
)
725+
assert conn.auth_class.provider == AttestationProvider.GCP
726+
assert conn.auth_class.impersonation_path == [
727+
"sa2@project.iam.gserviceaccount.com"
728+
]
729+
730+
657731
@pytest.mark.parametrize(
658732
"provider_param, parsed_provider",
659733
[

test/wif/test_wif_async.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ async def test_should_authenticate_using_oidc_async():
6262

6363
@pytest.mark.wif
6464
@pytest.mark.aio
65-
@pytest.mark.skip("Impersonation is still being developed")
6665
async def test_should_authenticate_with_impersonation_async():
6766
if not isinstance(IMPERSONATION_PATH, str) or not IMPERSONATION_PATH:
6867
pytest.skip("Skipping test - IMPERSONATION_PATH is not set")

0 commit comments

Comments
 (0)