Skip to content

Commit 6324c39

Browse files
sfc-gh-eqin, sfc-gh-pmansour
authored and committed
Support WIF Impersonation on GCP workloads (#2496)
Co-authored-by: Peter Mansour <peter.mansour@snowflake.com>
1 parent e5aeea7 commit 6324c39

File tree

7 files changed

+237
-11
lines changed

7 files changed

+237
-11
lines changed

src/snowflake/connector/auth/workload_identity.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ def __init__(
5555
provider: AttestationProvider,
5656
token: str | None = None,
5757
entra_resource: str | None = None,
58+
impersonation_path: list[str] | None = None,
5859
**kwargs,
5960
) -> None:
6061
super().__init__(**kwargs)
6162
self.provider = provider
6263
self.token = token
6364
self.entra_resource = entra_resource
65+
self.impersonation_path = impersonation_path
6466

6567
self.attestation: WorkloadIdentityAttestation | None = None
6668

@@ -85,6 +87,7 @@ def prepare(
8587
self.provider,
8688
self.entra_resource,
8789
self.token,
90+
self.impersonation_path,
8891
session_manager=(
8992
conn._session_manager.clone(max_retries=0) if conn else None
9093
),

src/snowflake/connector/connection.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def _get_private_bytes_from_file(
229229
"authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)),
230230
"workload_identity_provider": (None, (type(None), AttestationProvider)),
231231
"workload_identity_entra_resource": (None, (type(None), str)),
232+
"workload_identity_impersonation_path": (None, (type(None), list[str])),
232233
"mfa_callback": (None, (type(None), Callable)),
233234
"password_callback": (None, (type(None), Callable)),
234235
"auth_class": (None, (type(None), AuthByPlugin)),
@@ -1374,10 +1375,24 @@ def __open_connection(self):
13741375
"errno": ER_INVALID_WIF_SETTINGS,
13751376
},
13761377
)
1378+
if (
1379+
self._workload_identity_impersonation_path
1380+
and self._workload_identity_provider != AttestationProvider.GCP
1381+
):
1382+
Error.errorhandler_wrapper(
1383+
self,
1384+
None,
1385+
ProgrammingError,
1386+
{
1387+
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
1388+
"errno": ER_INVALID_WIF_SETTINGS,
1389+
},
1390+
)
13771391
self.auth_class = AuthByWorkloadIdentity(
13781392
provider=self._workload_identity_provider,
13791393
token=self._token,
13801394
entra_resource=self._workload_identity_entra_resource,
1395+
impersonation_path=self._workload_identity_impersonation_path,
13811396
)
13821397
else:
13831398
# okta URL, e.g., https://<account>.okta.com/
@@ -1550,6 +1565,7 @@ def __config(self, **kwargs):
15501565
workload_identity_dependent_options = [
15511566
"workload_identity_provider",
15521567
"workload_identity_entra_resource",
1568+
"workload_identity_impersonation_path",
15531569
]
15541570
for dependent_option in workload_identity_dependent_options:
15551571
if (

src/snowflake/connector/wif_util.py

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
logger = logging.getLogger(__name__)
2424
SNOWFLAKE_AUDIENCE = "snowflakecomputing.com"
2525
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE = "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad"
26+
GCP_METADATA_SERVICE_ACCOUNT_BASE_URL = (
27+
"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default"
28+
)
2629

2730

2831
@unique
@@ -193,29 +196,103 @@ def create_aws_attestation(
193196
)
194197

195198

196-
def create_gcp_attestation(
197-
session_manager: SessionManager | None = None,
198-
) -> WorkloadIdentityAttestation:
199-
"""Tries to create a workload identity attestation for GCP.
199+
def get_gcp_access_token(session_manager: SessionManager) -> str:
    """Gets a GCP access token for the default service account.

    Sends a GET request to the ``/token`` endpoint under
    ``GCP_METADATA_SERVICE_ACCOUNT_BASE_URL`` with the ``Metadata-Flavor: Google``
    header and reads the ``access_token`` field of the JSON response.

    Args:
        session_manager: Session manager used to issue the HTTP request.

    Returns:
        The value of the ``access_token`` field in the response body.

    Raises:
        ProgrammingError: With errno ``ER_WIF_CREDENTIALS_NOT_FOUND`` if the request
            fails, the response has an error status, or the body cannot be parsed
            (e.g. the application isn't running on GCP or has no credentials).
    """
    try:
        res = session_manager.request(
            method="GET",
            url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/token",
            headers={
                "Metadata-Flavor": "Google",
            },
        )
        res.raise_for_status()
        return res.json()["access_token"]
    except Exception as e:
        # Chain explicitly so the underlying failure is kept as __cause__
        # instead of relying on implicit exception context.
        raise ProgrammingError(
            msg=f"Error fetching GCP access token: {e}. Ensure the application is running on GCP.",
            errno=ER_WIF_CREDENTIALS_NOT_FOUND,
        ) from e
219+
220+
221+
def get_gcp_identity_token_via_impersonation(
    impersonation_path: list[str], session_manager: SessionManager
) -> str:
    """Gets a GCP identity token for the last service account in a delegation chain.

    The access token of the workload's own service account (obtained through
    ``get_gcp_access_token``) is used as the bearer credential for a POST to the
    IAM Credentials ``generateIdToken`` endpoint of the final service account in
    ``impersonation_path``. All preceding entries are sent as ``delegates`` and
    the requested audience is ``SNOWFLAKE_AUDIENCE``.

    Args:
        impersonation_path: Service account identifiers, in delegation order. The
            last entry is the account whose identity token is requested; it must
            not be empty.
        session_manager: Session manager used to issue the HTTP requests.

    Returns:
        The value of the ``token`` field in the ``generateIdToken`` response.

    Raises:
        ProgrammingError: With errno ``ER_WIF_CREDENTIALS_NOT_FOUND`` if
            ``impersonation_path`` is empty, if the access token cannot be
            fetched, or if the ``generateIdToken`` call fails.
    """
    if not impersonation_path:
        raise ProgrammingError(
            msg="Error: impersonation_path cannot be empty.",
            errno=ER_WIF_CREDENTIALS_NOT_FOUND,
        )

    current_sa_token = get_gcp_access_token(session_manager)
    # Expand each identifier to its resource name in a separate local, so the
    # parameter keeps its original meaning; the wildcard project ("-") is used.
    *delegates, target = [
        f"projects/-/serviceAccounts/{client_id}" for client_id in impersonation_path
    ]
    try:
        res = session_manager.post(
            url=f"https://iamcredentials.googleapis.com/v1/{target}:generateIdToken",
            headers={
                "Authorization": f"Bearer {current_sa_token}",
                "Content-Type": "application/json",
            },
            json={
                "delegates": delegates,
                "audience": SNOWFLAKE_AUDIENCE,
            },
        )
        res.raise_for_status()
        return res.json()["token"]
    except Exception as e:
        # Chain explicitly so the underlying failure is kept as __cause__.
        raise ProgrammingError(
            msg=f"Error fetching GCP identity token for impersonated GCP service account '{target}': {e}. Ensure the application is running on GCP.",
            errno=ER_WIF_CREDENTIALS_NOT_FOUND,
        ) from e
257+
258+
259+
def get_gcp_identity_token(session_manager: SessionManager) -> str:
260+
"""Gets a GCP identity token from the metadata server.
200261
201262
If the application isn't running on GCP or no credentials were found, raises an error.
202263
"""
203264
try:
204265
res = session_manager.request(
205266
method="GET",
206-
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
267+
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/identity?audience={SNOWFLAKE_AUDIENCE}",
207268
headers={
208269
"Metadata-Flavor": "Google",
209270
},
210271
)
211272
res.raise_for_status()
273+
return res.content.decode("utf-8")
212274
except Exception as e:
213275
raise ProgrammingError(
214-
msg=f"Error fetching GCP metadata: {e}. Ensure the application is running on GCP.",
276+
msg=f"Error fetching GCP identity token: {e}. Ensure the application is running on GCP.",
215277
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
216278
)
217279

218-
jwt_str = res.content.decode("utf-8")
280+
281+
def create_gcp_attestation(
282+
session_manager: SessionManager,
283+
impersonation_path: list[str] | None = None,
284+
) -> WorkloadIdentityAttestation:
285+
"""Tries to create a workload identity attestation for GCP.
286+
287+
If the application isn't running on GCP or no credentials were found, raises an error.
288+
"""
289+
if impersonation_path:
290+
jwt_str = get_gcp_identity_token_via_impersonation(
291+
impersonation_path, session_manager
292+
)
293+
else:
294+
jwt_str = get_gcp_identity_token(session_manager)
295+
219296
_, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
220297
return WorkloadIdentityAttestation(
221298
AttestationProvider.GCP, jwt_str, {"sub": subject}
@@ -304,6 +381,7 @@ def create_attestation(
304381
provider: AttestationProvider,
305382
entra_resource: str | None = None,
306383
token: str | None = None,
384+
impersonation_path: list[str] | None = None,
307385
session_manager: SessionManager | None = None,
308386
) -> WorkloadIdentityAttestation:
309387
"""Entry point to create an attestation using the given provider.
@@ -322,7 +400,7 @@ def create_attestation(
322400
elif provider == AttestationProvider.AZURE:
323401
return create_azure_attestation(entra_resource, session_manager)
324402
elif provider == AttestationProvider.GCP:
325-
return create_gcp_attestation(session_manager)
403+
return create_gcp_attestation(session_manager, impersonation_path)
326404
elif provider == AttestationProvider.OIDC:
327405
return create_oidc_attestation(token)
328406
else:

test/csp_helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ def gen_dummy_id_token(
4040
)
4141

4242

43+
def gen_dummy_access_token(sub="test-subject") -> str:
    """Builds a deterministic fake access token for the given subject.

    The token is the hex encoding of the UTF-8 bytes of the subject followed by
    a fixed suffix, so the same subject always yields the same token.
    """
    logger.debug(f"Generating dummy access token for subject {sub}")
    token_material = sub + "secret"
    return token_material.encode("utf-8").hex()
48+
49+
4350
def build_response(content: bytes, status_code: int = 200, headers=None) -> Response:
4451
"""Builds a requests.Response object with the given status code and content."""
4552
response = Response()
@@ -285,6 +292,19 @@ def handle_request(self, method, parsed_url, headers, timeout):
285292
audience = query_string["audience"][0]
286293
self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience)
287294
return build_response(self.token.encode("utf-8"))
295+
elif (
296+
method == "GET"
297+
and parsed_url.path
298+
== "/computeMetadata/v1/instance/service-accounts/default/token"
299+
and headers.get("Metadata-Flavor") == "Google"
300+
):
301+
self.token = gen_dummy_access_token(sub=self.sub)
302+
ret = {
303+
"access_token": self.token,
304+
"expires_in": 3599,
305+
"token_type": "Bearer",
306+
}
307+
return build_response(json.dumps(ret).encode("utf-8"))
288308
else:
289309
# Reject malformed requests.
290310
raise HTTPError()

test/unit/test_auth_workload_identity.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
)
1717
from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname
1818

19-
from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token
19+
from ..csp_helpers import (
20+
FakeAwsEnvironment,
21+
FakeGceMetadataService,
22+
build_response,
23+
gen_dummy_access_token,
24+
gen_dummy_id_token,
25+
)
2026

2127
logger = logging.getLogger(__name__)
2228

@@ -288,7 +294,7 @@ def test_explicit_gcp_metadata_server_error_bubbles_up(exception):
288294
with pytest.raises(ProgrammingError) as excinfo:
289295
auth_class.prepare(conn=None)
290296

291-
assert "Error fetching GCP metadata:" in str(excinfo.value)
297+
assert "Error fetching GCP identity token:" in str(excinfo.value)
292298
assert "Ensure the application is running on GCP." in str(excinfo.value)
293299

294300

@@ -316,6 +322,44 @@ def test_explicit_gcp_generates_unique_assertion_content(
316322
assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}'
317323

318324

325+
@mock.patch("snowflake.connector.session_manager.SessionManager.post")
def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa(
    mock_post_request, fake_gce_metadata_service: FakeGceMetadataService
):
    """Impersonating sa1 -> sa2 -> sa3 must call generateIdToken once for sa3.

    The POST must carry sa1's access token as the bearer credential and list
    sa2 as the only delegate; the resulting auth data must describe sa3.
    """
    # The workload itself runs as sa1.
    fake_gce_metadata_service.sub = "sa1"
    expected_bearer = gen_dummy_access_token("sa1")
    final_id_token = gen_dummy_id_token("sa3")

    # The IAM Credentials call is mocked to hand back sa3's identity token.
    response_body = json.dumps({"token": final_id_token}).encode("utf-8")
    mock_post_request.return_value = build_response(response_body)

    auth_class = AuthByWorkloadIdentity(
        provider=AttestationProvider.GCP, impersonation_path=["sa2", "sa3"]
    )
    auth_class.prepare(conn=None)

    expected_headers = {
        "Authorization": f"Bearer {expected_bearer}",
        "Content-Type": "application/json",
    }
    expected_payload = {
        "delegates": ["projects/-/serviceAccounts/sa2"],
        "audience": "snowflakecomputing.com",
    }
    mock_post_request.assert_called_once_with(
        url="https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa3:generateIdToken",
        headers=expected_headers,
        json=expected_payload,
    )

    assert auth_class.assertion_content == '{"_provider":"GCP","sub":"sa3"}'
    expected_api_data = {
        "AUTHENTICATOR": "WORKLOAD_IDENTITY",
        "PROVIDER": "GCP",
        "TOKEN": final_id_token,
    }
    assert extract_api_data(auth_class) == expected_api_data
361+
362+
319363
# -- Azure Tests --
320364

321365

test/unit/test_connection.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ def test_otel_error_message(caplog, mock_post_requests):
631631
"workload_identity_entra_resource",
632632
"api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b",
633633
),
634+
("workload_identity_impersonation_path", ["subject-b", "subject-c"]),
634635
],
635636
)
636637
def test_cannot_set_dependent_params_without_wlid_authenticator(
@@ -679,6 +680,71 @@ def test_workload_identity_provider_is_required_for_wif_authenticator(
679680
assert expected_error_msg in str(excinfo.value)
680681

681682

683+
@pytest.mark.parametrize(
    "provider_param",
    [
        # Strongly-typed values.
        AttestationProvider.AWS,
        AttestationProvider.AZURE,
        AttestationProvider.OIDC,
        # String values.
        "AWS",
        "AZURE",
        "OIDC",
    ],
)
def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
    monkeypatch, provider_param
):
    """Connecting with an impersonation path and a non-GCP provider must fail."""
    expected_msg = (
        "workload_identity_impersonation_path is currently only supported for GCP."
    )
    connect_kwargs = {
        "account": "account",
        "authenticator": "WORKLOAD_IDENTITY",
        "workload_identity_provider": provider_param,
        "workload_identity_impersonation_path": [
            "sa2@project.iam.gserviceaccount.com"
        ],
    }

    with monkeypatch.context() as m:
        # Stub out authentication so only the parameter validation runs.
        m.setattr(
            "snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None
        )

        with pytest.raises(ProgrammingError) as excinfo:
            snowflake.connector.connect(**connect_kwargs)

        assert expected_msg in str(excinfo.value)
717+
718+
719+
@pytest.mark.parametrize(
    "provider_param",
    [
        AttestationProvider.GCP,
        "GCP",
    ],
)
def test_workload_identity_impersonation_path_supported_for_gcp_provider(
    monkeypatch, provider_param
):
    """With the GCP provider, the impersonation path reaches the auth class."""
    requested_path = ["sa2@project.iam.gserviceaccount.com"]

    with monkeypatch.context() as m:
        # Stub out authentication so connect() only builds the auth class.
        m.setattr(
            "snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None
        )

        conn = snowflake.connector.connect(
            account="account",
            authenticator="WORKLOAD_IDENTITY",
            workload_identity_provider=provider_param,
            workload_identity_impersonation_path=requested_path,
        )
        created_auth = conn.auth_class
        assert created_auth.provider == AttestationProvider.GCP
        assert created_auth.impersonation_path == [
            "sa2@project.iam.gserviceaccount.com"
        ]
746+
747+
682748
@pytest.mark.parametrize(
683749
"provider_param, parsed_provider",
684750
[

test/wif/test_wif.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def test_should_authenticate_using_oidc():
5959

6060

6161
@pytest.mark.wif
62-
@pytest.mark.skip("Impersonation is still being developed")
6362
def test_should_authenticate_with_impersonation():
6463
if not isinstance(IMPERSONATION_PATH, str) or not IMPERSONATION_PATH:
6564
pytest.skip("Skipping test - IMPERSONATION_PATH is not set")

0 commit comments

Comments
 (0)