1414import jwt
1515import pytest
1616
17- from snowflake .connector .aio ._wif_util import AttestationProvider
17+ from snowflake .connector .aio ._wif_util import (
18+ AttestationProvider ,
19+ WorkloadIdentityAttestation ,
20+ )
1821from snowflake .connector .aio .auth import AuthByWorkloadIdentity
1922from snowflake .connector .errors import ProgrammingError
2023
2124from ...csp_helpers import gen_dummy_access_token , gen_dummy_id_token
25+ from ...helpers import apply_auth_class_update_body_async , create_mock_auth_body
2226from .csp_helpers_async import FakeAwsEnvironmentAsync , FakeGceMetadataServiceAsync
2327
2428logger = logging .getLogger (__name__ )
@@ -137,6 +141,42 @@ async def mock_post(*args, **kwargs):
137141 await connection .close ()
138142
139143
144+ @pytest .mark .parametrize (
145+ "provider,additional_args" ,
146+ [
147+ (AttestationProvider .AWS , {}),
148+ (AttestationProvider .GCP , {}),
149+ (AttestationProvider .AZURE , {}),
150+ (
151+ AttestationProvider .OIDC ,
152+ {"token" : gen_dummy_id_token (sub = "service-1" , iss = "issuer-1" )},
153+ ),
154+ ],
155+ )
156+ async def test_auth_prepare_body_does_not_overwrite_client_environment_fields (
157+ provider , additional_args
158+ ):
159+ auth_class = AuthByWorkloadIdentity (provider = provider , ** additional_args )
160+ auth_class .attestation = WorkloadIdentityAttestation (
161+ provider = AttestationProvider .GCP ,
162+ credential = None ,
163+ user_identifier_components = None ,
164+ )
165+
166+ req_body_before = create_mock_auth_body ()
167+ req_body_after = await apply_auth_class_update_body_async (
168+ auth_class , req_body_before
169+ )
170+
171+ assert all (
172+ [
173+ req_body_before ["data" ]["CLIENT_ENVIRONMENT" ][k ]
174+ == req_body_after ["data" ]["CLIENT_ENVIRONMENT" ][k ]
175+ for k in req_body_before ["data" ]["CLIENT_ENVIRONMENT" ]
176+ ]
177+ )
178+
179+
140180# -- OIDC Tests --
141181
142182
@@ -151,6 +191,7 @@ async def test_explicit_oidc_valid_inline_token_plumbed_to_api():
151191 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
152192 "PROVIDER" : "OIDC" ,
153193 "TOKEN" : dummy_token ,
194+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" : 0 },
154195 }
155196
156197
@@ -208,6 +249,9 @@ async def test_explicit_aws_encodes_audience_host_signature_to_api(
208249 data = await extract_api_data (auth_class )
209250 assert data ["AUTHENTICATOR" ] == "WORKLOAD_IDENTITY"
210251 assert data ["PROVIDER" ] == "AWS"
252+ assert (
253+ data ["CLIENT_ENVIRONMENT" ]["WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" ] == 0
254+ )
211255 verify_aws_token (data ["TOKEN" ], fake_aws_environment .region )
212256
213257
@@ -309,6 +353,7 @@ async def test_explicit_gcp_plumbs_token_to_api(
309353 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
310354 "PROVIDER" : "GCP" ,
311355 "TOKEN" : fake_gce_metadata_service .token ,
356+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" : 0 },
312357 }
313358
314359
@@ -365,6 +410,7 @@ def __init__(self, content):
365410 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
366411 "PROVIDER" : "GCP" ,
367412 "TOKEN" : sa3_id_token ,
413+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" : 2 },
368414 }
369415
370416
@@ -419,6 +465,7 @@ async def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service):
419465 "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
420466 "PROVIDER" : "AZURE" ,
421467 "TOKEN" : fake_azure_metadata_service .token ,
468+ "CLIENT_ENVIRONMENT" : {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH" : 0 },
422469 }
423470
424471
0 commit comments