|
10 | 10 | # See the License for the specific language governing permissions and |
11 | 11 | # limitations under the License. |
12 | 12 | import json |
13 | | -import re |
14 | 13 | import threading |
15 | 14 | import time |
16 | 15 | import uuid |
17 | | -from collections import namedtuple |
18 | 16 | from unittest import mock |
19 | 17 | from urllib.parse import urlparse |
20 | 18 |
|
|
25 | 23 | from requests_kerberos.exceptions import KerberosExchangeError |
26 | 24 |
|
27 | 25 | import trino.exceptions |
| 26 | +from tests.unit.oauth_test_utils import RedirectHandler, GetTokenCallback, PostStatementCallback, \ |
| 27 | + MultithreadedTokenServer, _post_statement_requests, _get_token_requests, REDIRECT_RESOURCE, TOKEN_RESOURCE, \ |
| 28 | + SERVER_ADDRESS |
28 | 29 | from trino import constants |
29 | 30 | from trino.auth import KerberosAuthentication, _OAuth2TokenBearer |
30 | 31 | from trino.client import TrinoQuery, TrinoRequest, TrinoResult |
@@ -259,52 +260,6 @@ def long_call(request, uri, headers): |
259 | 260 | httpretty.reset() |
260 | 261 |
|
261 | 262 |
|
262 | | -SERVER_ADDRESS = "https://coordinator" |
263 | | -REDIRECT_PATH = "oauth2/initiate" |
264 | | -TOKEN_PATH = "oauth2/token" |
265 | | -REDIRECT_RESOURCE = f"{SERVER_ADDRESS}/{REDIRECT_PATH}" |
266 | | -TOKEN_RESOURCE = f"{SERVER_ADDRESS}/{TOKEN_PATH}" |
267 | | - |
268 | | - |
269 | | -class RedirectHandler: |
270 | | - def __init__(self): |
271 | | - self.redirect_server = "" |
272 | | - |
273 | | - def __call__(self, url): |
274 | | - self.redirect_server += url |
275 | | - |
276 | | - |
277 | | -class PostStatementCallback: |
278 | | - def __init__(self, redirect_server, token_server, tokens, sample_post_response_data): |
279 | | - self.redirect_server = redirect_server |
280 | | - self.token_server = token_server |
281 | | - self.tokens = tokens |
282 | | - self.sample_post_response_data = sample_post_response_data |
283 | | - |
284 | | - def __call__(self, request, uri, response_headers): |
285 | | - authorization = request.headers.get("Authorization") |
286 | | - if authorization and authorization.replace("Bearer ", "") in self.tokens: |
287 | | - return [200, response_headers, json.dumps(self.sample_post_response_data)] |
288 | | - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", ' |
289 | | - f'x_token_server="{self.token_server}"', |
290 | | - 'Basic realm': '"Trino"'}, ""] |
291 | | - |
292 | | - |
293 | | -class GetTokenCallback: |
294 | | - def __init__(self, token_server, token, attempts=1): |
295 | | - self.token_server = token_server |
296 | | - self.token = token |
297 | | - self.attempts = attempts |
298 | | - |
299 | | - def __call__(self, request, uri, response_headers): |
300 | | - self.attempts -= 1 |
301 | | - if self.attempts < 0: |
302 | | - return [404, response_headers, "{}"] |
303 | | - if self.attempts == 0: |
304 | | - return [200, response_headers, f'{{"token": "{self.token}"}}'] |
305 | | - return [200, response_headers, f'{{"nextUri": "{self.token_server}"}}'] |
306 | | - |
307 | | - |
308 | 263 | @pytest.mark.parametrize("attempts", [1, 3, 5]) |
309 | 264 | @httprettified |
310 | 265 | def test_oauth2_authentication_flow(attempts, sample_post_response_data): |
@@ -511,57 +466,6 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon |
511 | 466 | assert len(_get_token_requests(challenge_id)) == 1 |
512 | 467 |
|
513 | 468 |
|
514 | | -class MultithreadedTokenServer: |
515 | | - Challenge = namedtuple('Challenge', ['token', 'attempts']) |
516 | | - |
517 | | - def __init__(self, sample_post_response_data, attempts=1): |
518 | | - self.tokens = set() |
519 | | - self.challenges = {} |
520 | | - self.sample_post_response_data = sample_post_response_data |
521 | | - self.attempts = attempts |
522 | | - |
523 | | - # bind post statement |
524 | | - httpretty.register_uri( |
525 | | - method=httpretty.POST, |
526 | | - uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", |
527 | | - body=self.post_statement_callback) |
528 | | - |
529 | | - # bind get token |
530 | | - httpretty.register_uri( |
531 | | - method=httpretty.GET, |
532 | | - uri=re.compile(rf"{TOKEN_RESOURCE}/.*"), |
533 | | - body=self.get_token_callback) |
534 | | - |
535 | | - # noinspection PyUnusedLocal |
536 | | - def post_statement_callback(self, request, uri, response_headers): |
537 | | - authorization = request.headers.get("Authorization") |
538 | | - |
539 | | - if authorization and authorization.replace("Bearer ", "") in self.tokens: |
540 | | - return [200, response_headers, json.dumps(self.sample_post_response_data)] |
541 | | - |
542 | | - challenge_id = str(uuid.uuid4()) |
543 | | - token = str(uuid.uuid4()) |
544 | | - self.tokens.add(token) |
545 | | - self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts) |
546 | | - redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" |
547 | | - token_server = f"{TOKEN_RESOURCE}/{challenge_id}" |
548 | | - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", ' |
549 | | - f'x_token_server="{token_server}"', |
550 | | - 'Basic realm': '"Trino"'}, ""] |
551 | | - |
552 | | - # noinspection PyUnusedLocal |
553 | | - def get_token_callback(self, request, uri, response_headers): |
554 | | - challenge_id = uri.replace(f"{TOKEN_RESOURCE}/", "") |
555 | | - challenge = self.challenges[challenge_id] |
556 | | - challenge = challenge._replace(attempts=challenge.attempts - 1) |
557 | | - self.challenges[challenge_id] = challenge |
558 | | - if challenge.attempts < 0: |
559 | | - return [404, response_headers, "{}"] |
560 | | - if challenge.attempts == 0: |
561 | | - return [200, response_headers, f'{{"token": "{challenge.token}"}}'] |
562 | | - return [200, response_headers, f'{{"nextUri": "{uri}"}}'] |
563 | | - |
564 | | - |
565 | 469 | @httprettified |
566 | 470 | def test_multithreaded_oauth2_authentication_flow(sample_post_response_data): |
567 | 471 | redirect_handler = RedirectHandler() |
@@ -598,31 +502,19 @@ def run(self) -> None: |
598 | 502 | for thread in threads: |
599 | 503 | thread.join() |
600 | 504 |
|
601 | | - # should issue only 3 tokens and each thread should get one |
602 | | - assert len(token_server.tokens) == 3 |
| 505 | + # should issue only 1 token and each thread should reuse it |
| 506 | + assert len(token_server.tokens) == 1 |
603 | 507 | for thread in threads: |
604 | 508 | assert thread.token in token_server.tokens |
605 | 509 |
|
606 | | - # should start only 3 challenges and every token should be obtained |
607 | | - assert len(token_server.challenges.keys()) == 3 |
| 510 | + # should start only 1 challenge |
| 511 | + assert len(token_server.challenges.keys()) == 1 |
608 | 512 | for challenge_id, challenge in token_server.challenges.items(): |
609 | 513 | assert f"{REDIRECT_RESOURCE}/{challenge_id}" in redirect_handler.redirect_server |
610 | 514 | assert challenge.attempts == 0 |
611 | 515 | assert len(_get_token_requests(challenge_id)) == 1 |
612 | 516 | # 3 threads * (10 POST /statement each + 1 replied request by authentication) |
613 | | - assert len(_post_statement_requests()) == 33 |
614 | | - |
615 | | - |
616 | | -def _get_token_requests(challenge_id): |
617 | | - return list(filter( |
618 | | - lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", |
619 | | - httpretty.latest_requests())) |
620 | | - |
621 | | - |
622 | | -def _post_statement_requests(): |
623 | | - return list(filter( |
624 | | - lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, |
625 | | - httpretty.latest_requests())) |
| 517 | + assert len(_post_statement_requests()) == 31 |
626 | 518 |
|
627 | 519 |
|
628 | 520 | @mock.patch("trino.client.TrinoRequest.http") |
|
0 commit comments