Skip to content

Commit 13b7365

Browse files
sfc-gh-pczajkasfc-gh-turbaszek
authored andcommitted
[async] Add WIF impersonation path length as data sent to Snowflake backend (#2521)
1 parent 7976686 commit 13b7365

10 files changed

+222
-1
lines changed

test/helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,9 @@ def apply_auth_class_update_body(auth_class, req_body_before):
340340
req_body_after = copy.deepcopy(req_body_before)
341341
auth_class.update_body(req_body_after)
342342
return req_body_after
343+
344+
345+
async def apply_auth_class_update_body_async(auth_class, req_body_before):
346+
req_body_after = copy.deepcopy(req_body_before)
347+
await auth_class.update_body(req_body_after)
348+
return req_body_after

test/unit/aio/test_auth_async.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
import inspect
1010
import sys
11+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
1112
from test.unit.aio.mock_utils import mock_connection
1213
from unittest.mock import Mock, PropertyMock
1314

@@ -340,3 +341,21 @@ def test_mro():
340341
assert AuthByDefault.mro().index(AuthByPluginAsync) < AuthByDefault.mro().index(
341342
AuthByPluginSync
342343
)
344+
345+
346+
async def test_auth_by_default_prepare_body_does_not_overwrite_client_environment_fields():
347+
password = "testpassword"
348+
auth_class = AuthByDefault(password)
349+
350+
req_body_before = create_mock_auth_body()
351+
req_body_after = await apply_auth_class_update_body_async(
352+
auth_class, req_body_before
353+
)
354+
355+
assert all(
356+
[
357+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
358+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
359+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
360+
]
361+
)

test/unit/aio/test_auth_keypair_async.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
89
from test.unit.aio.mock_utils import mock_connection
910
from unittest.mock import Mock, PropertyMock, patch
1011

@@ -61,6 +62,24 @@ async def test_auth_keypair(authenticator):
6162
assert rest.master_token == "MASTER_TOKEN"
6263

6364

65+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
66+
private_key_der, _ = generate_key_pair(2048)
67+
auth_class = AuthByKeyPair(private_key=private_key_der)
68+
69+
req_body_before = create_mock_auth_body()
70+
req_body_after = await apply_auth_class_update_body_async(
71+
auth_class, req_body_before
72+
)
73+
74+
assert all(
75+
[
76+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
77+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
78+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
79+
]
80+
)
81+
82+
6483
async def test_auth_keypair_abc():
6584
"""Simple Key Pair test using abstraction layer."""
6685
private_key_der, public_key_der_encoded = generate_key_pair(2048)

test/unit/aio/test_auth_oauth_async.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from __future__ import annotations
77

8+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
9+
810
import pytest
911

1012
from snowflake.connector.aio.auth import AuthByOAuth
@@ -20,6 +22,24 @@ async def test_auth_oauth():
2022
assert body["data"]["AUTHENTICATOR"] == "OAUTH", body
2123

2224

25+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
26+
token = "oAuthToken"
27+
auth_class = AuthByOAuth(token)
28+
29+
req_body_before = create_mock_auth_body()
30+
req_body_after = await apply_auth_class_update_body_async(
31+
auth_class, req_body_before
32+
)
33+
34+
assert all(
35+
[
36+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
37+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
38+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
39+
]
40+
)
41+
42+
2343
@pytest.mark.parametrize("authenticator", ["oauth", "OAUTH"])
2444
async def test_oauth_authenticator_is_case_insensitive(monkeypatch, authenticator):
2545
"""Test that oauth authenticator is case insensitive."""

test/unit/aio/test_auth_oauth_auth_code_async.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#
55

66
import unittest.mock as mock
7+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
78
from unittest.mock import patch
89

910
import pytest
@@ -44,6 +45,34 @@ async def test_auth_oauth_auth_code_oauth_type(omit_oauth_urls_check):
4445
)
4546

4647

48+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields(
49+
omit_oauth_urls_check,
50+
):
51+
auth_class = AuthByOauthCode(
52+
"app",
53+
"clientId",
54+
"clientSecret",
55+
"auth_url",
56+
"tokenRequestUrl",
57+
"redirectUri:{port}",
58+
"scope",
59+
"host",
60+
)
61+
62+
req_body_before = create_mock_auth_body()
63+
req_body_after = await apply_auth_class_update_body_async(
64+
auth_class, req_body_before
65+
)
66+
67+
assert all(
68+
[
69+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
70+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
71+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
72+
]
73+
)
74+
75+
4776
@pytest.mark.parametrize("rtr_enabled", [True, False])
4877
async def test_auth_oauth_auth_code_single_use_refresh_tokens(
4978
rtr_enabled: bool, omit_oauth_urls_check

test/unit/aio/test_auth_oauth_credentials_async.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from __future__ import annotations
77

8+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
9+
810
import pytest
911

1012
from snowflake.connector.aio.auth import AuthByOauthCredentials
@@ -27,6 +29,29 @@ async def test_auth_oauth_credentials_oauth_type():
2729
)
2830

2931

32+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
33+
auth_class = AuthByOauthCredentials(
34+
"app",
35+
"clientId",
36+
"clientSecret",
37+
"https://example.com/oauth/token",
38+
"scope",
39+
)
40+
41+
req_body_before = create_mock_auth_body()
42+
req_body_after = await apply_auth_class_update_body_async(
43+
auth_class, req_body_before
44+
)
45+
46+
assert all(
47+
[
48+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
49+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
50+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
51+
]
52+
)
53+
54+
3055
@pytest.mark.parametrize(
3156
"authenticator, oauth_credentials_in_body",
3257
[

test/unit/aio/test_auth_okta_async.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import logging
9+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
910
from test.unit.aio.mock_utils import mock_connection
1011
from unittest.mock import MagicMock, Mock, PropertyMock, patch
1112

@@ -18,6 +19,24 @@
1819
from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION
1920

2021

22+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
23+
application = "testapplication"
24+
auth_class = AuthByOkta(application)
25+
26+
req_body_before = create_mock_auth_body()
27+
req_body_after = await apply_auth_class_update_body_async(
28+
auth_class, req_body_before
29+
)
30+
31+
assert all(
32+
[
33+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
34+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
35+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
36+
]
37+
)
38+
39+
2140
async def test_auth_okta():
2241
"""Authentication by OKTA positive test case."""
2342
authenticator = "https://testsso.snowflake.net/"

test/unit/aio/test_auth_pat_async.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from __future__ import annotations
77

8+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
9+
810
import pytest
911

1012
from snowflake.connector.aio.auth import AuthByPAT
@@ -27,6 +29,24 @@ async def test_auth_pat():
2729
assert auth.assertion_content is None
2830

2931

32+
async def test_pat_prepare_body_does_not_overwrite_client_environment_fields():
33+
token = "patToken"
34+
auth_class = AuthByPAT(token)
35+
36+
req_body_before = create_mock_auth_body()
37+
req_body_after = await apply_auth_class_update_body_async(
38+
auth_class, req_body_before
39+
)
40+
41+
assert all(
42+
[
43+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
44+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
45+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
46+
]
47+
)
48+
49+
3050
async def test_auth_pat_reauthenticate():
3151
"""Test PAT reauthenticate."""
3252
token = "patToken"

test/unit/aio/test_auth_webbrowser_async.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
import base64
1010
import socket
11+
from test.helpers import apply_auth_class_update_body_async, create_mock_auth_body
1112
from test.unit.aio.mock_utils import mock_connection
1213
from unittest import mock
1314
from unittest.mock import MagicMock, Mock, PropertyMock, patch
@@ -1008,6 +1009,22 @@ async def mock_webbrowser_auth_prepare(
10081009
await conn.close()
10091010

10101011

1012+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields():
1013+
auth_class = AuthByWebBrowser(application=APPLICATION)
1014+
req_body_before = create_mock_auth_body()
1015+
req_body_after = await apply_auth_class_update_body_async(
1016+
auth_class, req_body_before
1017+
)
1018+
1019+
assert all(
1020+
[
1021+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
1022+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
1023+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
1024+
]
1025+
)
1026+
1027+
10111028
def test_mro():
10121029
"""Ensure that methods from AuthByPluginAsync override those from AuthByPlugin."""
10131030
from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync

test/unit/aio/test_auth_workload_identity_async.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
import jwt
1515
import pytest
1616

17-
from snowflake.connector.aio._wif_util import AttestationProvider
17+
from snowflake.connector.aio._wif_util import (
18+
AttestationProvider,
19+
WorkloadIdentityAttestation,
20+
)
1821
from snowflake.connector.aio.auth import AuthByWorkloadIdentity
1922
from snowflake.connector.errors import ProgrammingError
2023

2124
from ...csp_helpers import gen_dummy_access_token, gen_dummy_id_token
25+
from ...helpers import apply_auth_class_update_body_async, create_mock_auth_body
2226
from .csp_helpers_async import FakeAwsEnvironmentAsync, FakeGceMetadataServiceAsync
2327

2428
logger = logging.getLogger(__name__)
@@ -137,6 +141,42 @@ async def mock_post(*args, **kwargs):
137141
await connection.close()
138142

139143

144+
@pytest.mark.parametrize(
145+
"provider,additional_args",
146+
[
147+
(AttestationProvider.AWS, {}),
148+
(AttestationProvider.GCP, {}),
149+
(AttestationProvider.AZURE, {}),
150+
(
151+
AttestationProvider.OIDC,
152+
{"token": gen_dummy_id_token(sub="service-1", iss="issuer-1")},
153+
),
154+
],
155+
)
156+
async def test_auth_prepare_body_does_not_overwrite_client_environment_fields(
157+
provider, additional_args
158+
):
159+
auth_class = AuthByWorkloadIdentity(provider=provider, **additional_args)
160+
auth_class.attestation = WorkloadIdentityAttestation(
161+
provider=AttestationProvider.GCP,
162+
credential=None,
163+
user_identifier_components=None,
164+
)
165+
166+
req_body_before = create_mock_auth_body()
167+
req_body_after = await apply_auth_class_update_body_async(
168+
auth_class, req_body_before
169+
)
170+
171+
assert all(
172+
[
173+
req_body_before["data"]["CLIENT_ENVIRONMENT"][k]
174+
== req_body_after["data"]["CLIENT_ENVIRONMENT"][k]
175+
for k in req_body_before["data"]["CLIENT_ENVIRONMENT"]
176+
]
177+
)
178+
179+
140180
# -- OIDC Tests --
141181

142182

@@ -151,6 +191,7 @@ async def test_explicit_oidc_valid_inline_token_plumbed_to_api():
151191
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
152192
"PROVIDER": "OIDC",
153193
"TOKEN": dummy_token,
194+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
154195
}
155196

156197

@@ -208,6 +249,9 @@ async def test_explicit_aws_encodes_audience_host_signature_to_api(
208249
data = await extract_api_data(auth_class)
209250
assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY"
210251
assert data["PROVIDER"] == "AWS"
252+
assert (
253+
data["CLIENT_ENVIRONMENT"]["WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"] == 0
254+
)
211255
verify_aws_token(data["TOKEN"], fake_aws_environment.region)
212256

213257

@@ -309,6 +353,7 @@ async def test_explicit_gcp_plumbs_token_to_api(
309353
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
310354
"PROVIDER": "GCP",
311355
"TOKEN": fake_gce_metadata_service.token,
356+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
312357
}
313358

314359

@@ -365,6 +410,7 @@ def __init__(self, content):
365410
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
366411
"PROVIDER": "GCP",
367412
"TOKEN": sa3_id_token,
413+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 2},
368414
}
369415

370416

@@ -419,6 +465,7 @@ async def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service):
419465
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
420466
"PROVIDER": "AZURE",
421467
"TOKEN": fake_azure_metadata_service.token,
468+
"CLIENT_ENVIRONMENT": {"WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH": 0},
422469
}
423470

424471

0 commit comments

Comments
 (0)