From c76447643c692fdb675e52487d234469df5dda52 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Fri, 7 Nov 2025 23:51:01 +0000 Subject: [PATCH] [Identity] Add async token method concurrency control Implements per-scope lock mechanism to prevent concurrent token requests for the same token request argument combination, reducing unnecessary network calls and improving performance in high-concurrency async scenarios. Signed-off-by: Paul Van Eck --- sdk/identity/azure-identity/CHANGELOG.md | 2 + .../identity/aio/_internal/get_token_mixin.py | 83 +++- .../azure/identity/aio/_internal/utils.py | 29 ++ .../tests/test_get_token_mixin_async.py | 393 +++++++++++++++++- 4 files changed, 492 insertions(+), 15 deletions(-) create mode 100644 sdk/identity/azure-identity/azure/identity/aio/_internal/utils.py diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 399f2c761a08..663da5874dfa 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -10,6 +10,8 @@ ### Other Changes +- Add lock-based concurrency control to asynchronous token requests. This addresses race conditions where multiple coroutines could simultaneously request tokens for the same scopes, resulting in redundant authentication requests. ([#43889](https://github.com/Azure/azure-sdk-for-python/pull/43889)) + ## 1.26.0b1 (2025-11-07) ### Features Added diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py index 8b383fef9818..0831f955b618 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py @@ -4,12 +4,15 @@ # ------------------------------------ import abc import logging +import threading import time -from typing import Any, Optional +from typing import Any, Optional, Dict, Type +from weakref import WeakValueDictionary from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY from ..._internal import within_credential_chain +from .utils import get_running_async_lock_class _LOGGER = logging.getLogger(__name__) @@ -18,9 +21,39 @@ class GetTokenMixin(abc.ABC): def __init__(self, *args: Any, **kwargs: Any) -> None: self._last_request_time = 0 + self._global_lock: Optional[Any] = None + self._global_lock_init_lock = threading.Lock() + self._active_locks: WeakValueDictionary[tuple, Any] = WeakValueDictionary() + self._lock_class_type: Optional[Type] = None + # https://github.com/python/mypy/issues/5887 super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore + @property + def _lock_class(self) -> Type: + if self._lock_class_type is None: + self._lock_class_type = get_running_async_lock_class() + return self._lock_class_type + + async def _get_request_lock(self, lock_key: tuple) -> Any: + # Initialize global lock if needed, using threading.Lock for thread-safe initialization + if self._global_lock is None: + with self._global_lock_init_lock: + if self._global_lock is None: + self._global_lock = self._lock_class() + + lock = self._active_locks.get(lock_key) + if lock is not None: + return lock + + async with self._global_lock: + # Double-check in case another coroutine created it while we waited + lock = self._active_locks.get(lock_key) + if lock is None: + lock = self._lock_class() + 
self._active_locks[lock_key] = lock + return lock + @abc.abstractmethod async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]: """Attempt to acquire an access token from a cache or by redeeming a refresh token. @@ -132,19 +165,29 @@ async def _get_token_base( token = await self._acquire_token_silently( *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs ) - if not token: - self._last_request_time = int(time.time()) - token = await self._request_token( - *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs - ) - elif self._should_refresh(token): - try: - self._last_request_time = int(time.time()) - token = await self._request_token( + if not token or self._should_refresh(token): + # Get the lock specific to this scope combination + lock_key = (tuple(sorted(scopes)), claims, tenant_id, enable_cae) + lock = await self._get_request_lock(lock_key) + + async with lock: + # Double-check in case another coroutine refreshed the token while we waited for the lock + current_token = await self._acquire_token_silently( *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs ) - except Exception: # pylint:disable=broad-except - pass + if current_token and not self._should_refresh(current_token): + token = current_token + else: + try: + token = await self._request_token( + *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs + ) + except Exception: # pylint:disable=broad-except + self._last_request_time = int(time.time()) + # Only raise if we don't have a token to return + if not token: + raise + _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.INFO, "%s.%s succeeded", @@ -163,3 +206,19 @@ async def _get_token_base( exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) raise + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + # Remove the non-picklable entries + del state["_global_lock"] + del state["_lock_class_type"] + del state["_global_lock_init_lock"] + del state["_active_locks"] + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + self._global_lock_init_lock = threading.Lock() + self._active_locks = WeakValueDictionary() + self._global_lock = None + self._lock_class_type = None diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/utils.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/utils.py new file mode 100644 index 000000000000..45e2b92d9375 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/utils.py @@ -0,0 +1,29 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import sys +from typing import Type + + +def get_running_async_lock_class() -> Type: + """Get a lock class from the async library that the current context is running under. + + :return: The running async library's Lock class. + :rtype: Type[Lock] + :raises RuntimeError: if the current context is not running under an async library. + """ + + try: + import asyncio # pylint: disable=do-not-import-asyncio + + # Check if we are running in an asyncio event loop. + asyncio.get_running_loop() + return asyncio.Lock + except RuntimeError as err: + # Otherwise, assume we are running in a trio event loop if it has already been imported. 
+ if "trio" in sys.modules: + import trio # pylint: disable=networking-import-outside-azure-core-transport + + return trio.Lock + raise RuntimeError("An asyncio or trio event loop is required.") from err diff --git a/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py b/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py index 4258a4684148..3ffb9ece39aa 100644 --- a/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py +++ b/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py @@ -3,6 +3,10 @@ # Licensed under the MIT License. # ------------------------------------ import time +import asyncio +from asyncio import sleep as real_sleep +import gc +import weakref from unittest import mock from azure.core.credentials import AccessTokenInfo @@ -38,6 +42,23 @@ async def get_token_info(self, *_, **__): return await super().get_token_info(*_, **__) +class MockCredentialWithDelay(GetTokenMixin): + def __init__(self, cached_token=None): + super().__init__() + self.cached_token = cached_token + self.request_count = 0 + + async def _acquire_token_silently(self, *scopes, **kwargs): + return self.cached_token + + async def _request_token(self, *scopes, **kwargs): + self.request_count += 1 + request_count = self.request_count + await real_sleep(0.2) # Simulate network delay, give other coroutines time to queue up + self.cached_token = AccessTokenInfo(f"token_{request_count}", int(time.time() + 3600)) + return self.cached_token + + CACHED_TOKEN = "cached token" SCOPE = "scope" @@ -49,7 +70,9 @@ async def test_no_cached_token(get_token_method): credential = MockCredential() token = await getattr(credential, get_token_method)(SCOPE) - credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) + # Due to double-checking pattern in concurrency control, _acquire_token_silently may be called twice + assert credential.acquire_token_silently.call_count >= 1 + credential.acquire_token_silently.assert_any_call(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token @@ -86,7 +109,9 @@ async def test_expired_token(get_token_method): credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now - 1)) token = await getattr(credential, get_token_method)(SCOPE) - credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) + # Due to double-checking pattern in concurrency control, _acquire_token_silently may be called multiple times + assert credential.acquire_token_silently.call_count >= 1 + credential.acquire_token_silently.assert_any_call(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token @@ -114,7 +139,9 @@ async def test_cached_token_within_refresh_window(get_token_method): ) token = await getattr(credential, get_token_method)(SCOPE) - credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) + # Due to double-checking pattern in concurrency control, _acquire_token_silently may be called multiple times + assert credential.acquire_token_silently.call_count >= 1 + credential.acquire_token_silently.assert_any_call(SCOPE, claims=None, enable_cae=False, tenant_id=None) 
credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token @@ -133,3 +160,363 @@ async def test_retry_delay(get_token_method): assert token.token == CACHED_TOKEN credential.acquire_token_silently.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_concurrent_token_requests(get_token_method): + """When multiple coroutines request tokens concurrently, only one token request should be made""" + credential = MockCredentialWithDelay() + + # Launch multiple concurrent token requests + tasks = [getattr(credential, get_token_method)(SCOPE) for _ in range(5)] + tokens = await asyncio.gather(*tasks) + + # All tasks should get the same token + for token in tokens: + assert token.token == "token_1" + + # Only one token request should have been made despite 5 concurrent calls + assert credential.request_count == 1 + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_concurrent_token_refresh(get_token_method): + """When multiple coroutines need to refresh tokens concurrently, only one refresh should happen""" + + # Create a credential with a token that needs refresh + old_token = AccessTokenInfo("old_token", int(time.time() + DEFAULT_REFRESH_OFFSET - 1)) + credential = MockCredentialWithDelay(old_token) + + # Launch multiple concurrent token requests that need refresh + tasks = [getattr(credential, get_token_method)(SCOPE) for _ in range(5)] + tokens = await asyncio.gather(*tasks) + + # All tasks should get the refreshed token + for token in tokens: + assert token.token == "token_1" + + # Only one refresh should have been made despite 5 concurrent calls needing refresh + assert credential.request_count == 1 + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_concurrent_different_scopes_run_independently(get_token_method): + """When multiple coroutines request different scopes concurrently, they should run independently""" + + class MockCredentialWithScopeCache(GetTokenMixin): + def __init__(self): + super().__init__() + self.cached_tokens = {} # Store tokens by scope + self.request_count = 0 + + async def _acquire_token_silently(self, *scopes, **kwargs): + lock_key = tuple(sorted(scopes)) + return self.cached_tokens.get(lock_key) + + async def _request_token(self, *scopes, **kwargs): + self.request_count += 1 + request_count = self.request_count + await real_sleep(0.2) + lock_key = tuple(sorted(scopes)) + token = AccessTokenInfo(f"token_{lock_key[0]}_{request_count}", int(time.time() + 3600)) + self.cached_tokens[lock_key] = token + return token + + credential = MockCredentialWithScopeCache() + + # Create different scope combinations + scope1 = "scope1" + scope2 = "scope2" + scope3 = "scope3" + + # Launch concurrent requests for different scopes - these should NOT wait on each other + tasks = [ + getattr(credential, get_token_method)(scope1), + getattr(credential, get_token_method)(scope1), # Same scope - should wait + getattr(credential, get_token_method)(scope2), # Different scope - should run independently + getattr(credential, get_token_method)(scope2), # Same scope - should wait + getattr(credential, get_token_method)(scope3), # Different scope - should run independently + ] + + tokens = await asyncio.gather(*tasks) + + # Should have 
made 3 requests total (one for each unique scope) + assert credential.request_count == 3 + + # Check that tokens for the same scope are identical + assert tokens[0].token == tokens[1].token # scope1 tokens should be the same + assert tokens[2].token == tokens[3].token # scope2 tokens should be the same + + # Check that tokens for different scopes are different + assert tokens[0].token != tokens[2].token # scope1 != scope2 + assert tokens[0].token != tokens[4].token # scope1 != scope3 + assert tokens[2].token != tokens[4].token # scope2 != scope3 + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_concurrent_different_options_run_independently(get_token_method): + """When multiple coroutines request tokens with different options, they should run independently""" + + class RealisticMockCredential(GetTokenMixin): + def __init__(self): + super().__init__() + self.cached_tokens = {} # Store tokens by options key + self.request_count = 0 + + async def _acquire_token_silently(self, *scopes, **kwargs): + # Create a key based on all parameters that matter for caching + key = ( + tuple(sorted(scopes)), + kwargs.get("claims"), + kwargs.get("tenant_id"), + kwargs.get("enable_cae", False), + ) + + return self.cached_tokens.get(key) + + async def _request_token(self, *scopes, **kwargs): + self.request_count += 1 + request_count = self.request_count + await real_sleep(0.2) # Simulate network delay + key = ( + tuple(sorted(scopes)), + kwargs.get("claims"), + kwargs.get("tenant_id"), + kwargs.get("enable_cae", False), + ) + token = AccessTokenInfo(f"token_{request_count}", int(time.time() + 3600)) + self.cached_tokens[key] = token + return token + + credential = RealisticMockCredential() + + # Create tasks with different options that should run independently + if get_token_method == "get_token": + tasks = [ + credential.get_token(SCOPE), # No options + credential.get_token(SCOPE), # Same - should wait + credential.get_token(SCOPE, tenant_id="tenant1"), # Different tenant - should run independently + credential.get_token(SCOPE, tenant_id="tenant1"), # Same tenant - should wait + credential.get_token(SCOPE, claims="claim1"), # Different claims - should run independently + credential.get_token(SCOPE, enable_cae=True), # Different enable_cae - should run independently + ] + else: # get_token_info + tasks = [ + credential.get_token_info(SCOPE), # No options + credential.get_token_info(SCOPE), # Same - should wait + credential.get_token_info( + SCOPE, options={"tenant_id": "tenant1"} + ), # Different tenant - should run independently + credential.get_token_info(SCOPE, options={"tenant_id": "tenant1"}), # Same tenant - should wait + credential.get_token_info( + SCOPE, options={"claims": "claim1"} + ), # Different claims - should run independently + credential.get_token_info( + SCOPE, options={"enable_cae": True} + ), # Different enable_cae - should run independently + ] + + tokens = await asyncio.gather(*tasks) + + # Should have made 4 requests total (one for each unique option combination) + assert credential.request_count == 4 + + # Check that tokens for the same options are identical + assert tokens[0].token == tokens[1].token # Same options + assert tokens[2].token == tokens[3].token # Same tenant_id + + # Check that tokens for different options are different + assert tokens[0].token != tokens[2].token # No options != tenant_id + assert tokens[0].token != tokens[4].token # No options != claims + assert tokens[2].token != tokens[4].token # tenant_id != claims + assert 
tokens[0].token != tokens[5].token # No options != enable_cae + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_weakref_dictionary_cleanup(get_token_method): + """Test that locks are automatically cleaned up from the WeakValueDictionary when no longer referenced""" + + class MockCredentialWithLockTracking(GetTokenMixin): + def __init__(self): + super().__init__() + self.cached_tokens = {} + self.request_count = 0 + self.lock_refs = [] # Store strong references during requests + + async def _acquire_token_silently(self, *scopes, **kwargs): + return None + + async def _request_token(self, *scopes, **kwargs): + self.request_count += 1 + # Capture a reference to the current scope's lock + lock_key = ( + tuple(sorted(scopes)), + kwargs.get("claims"), + kwargs.get("tenant_id"), + kwargs.get("enable_cae", False), + ) + if lock_key in self._active_locks: + self.lock_refs.append(self._active_locks[lock_key]) + await asyncio.sleep(0.1) + lock_key_for_token = tuple(sorted(scopes)) + token = AccessTokenInfo(f"token_{lock_key_for_token[0]}", int(time.time() + 3600)) + self.cached_tokens[lock_key_for_token] = token + return token + + credential = MockCredentialWithLockTracking() + + # Request tokens for multiple scopes concurrently to create locks + scope1 = "scope1" + scope2 = "scope2" + scope3 = "scope3" + + # Start concurrent requests - these will create locks and hold them during _request_token + tasks = [ + getattr(credential, get_token_method)(scope1), + getattr(credential, get_token_method)(scope2), + getattr(credential, get_token_method)(scope3), + ] + + # Wait for all requests to complete + await asyncio.gather(*tasks) + + # At this point, lock_refs contains strong references to the locks + # that were created during the requests + initial_lock_count = len(credential.lock_refs) + assert initial_lock_count == 3, f"Should have captured 3 lock references, got {initial_lock_count}" + + assert len(credential._active_locks) == 3, "WeakValueDictionary should contain 3 locks" + + # Create weak references to track when locks are deallocated + weak_refs = [weakref.ref(lock) for lock in credential.lock_refs] + + # Verify all locks are still alive (because we hold strong refs in lock_refs) + assert all(ref() is not None for ref in weak_refs), "All locks should be alive while strong refs exist" + + # Clear the strong references + credential.lock_refs.clear() + + # Force garbage collection + gc.collect() + + # After GC, the weak references should be dead + assert all(ref() is None for ref in weak_refs), "All locks should be GC'd after strong refs are removed" + + # The WeakValueDictionary should automatically clean up dead entries + # Access the dict to trigger cleanup + remaining_locks = len(credential._active_locks) + + assert remaining_locks == 0, f"WeakValueDictionary should be empty after GC, but has {remaining_locks} entries" + + # Verify we can still use the credential and create new locks + credential.lock_refs.clear() + await getattr(credential, get_token_method)(scope1) + + # A new lock should be created + assert len(credential.lock_refs) == 1, "Should be able to create new locks after cleanup" + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_concurrent_refresh_with_failure_fallback(get_token_method): + """When refresh fails during concurrent requests, all should get the old token as fallback""" + + class MockCredentialWithFailure(GetTokenMixin): + def __init__(self, cached_token): + super().__init__() + self.cached_token = 
cached_token + self.request_count = 0 + + async def _acquire_token_silently(self, *scopes, **kwargs): + return self.cached_token + + async def _request_token(self, *scopes, **kwargs): + self.request_count += 1 + await real_sleep(0.2) + raise Exception("Network error during refresh") + + # Token that needs refresh but is not expired + old_token = AccessTokenInfo("fallback_token", int(time.time() + DEFAULT_REFRESH_OFFSET - 1)) + credential = MockCredentialWithFailure(old_token) + + # Launch concurrent requests - all should get fallback token + tasks = [getattr(credential, get_token_method)(SCOPE) for _ in range(5)] + tokens = await asyncio.gather(*tasks) + + # All should receive the fallback token + for token in tokens: + assert token.token == "fallback_token" + + # Only one refresh attempt should have been made + assert credential.request_count == 1 + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_concurrent_refresh_failure_no_fallback(get_token_method): + """When refresh fails with no cached token, all concurrent requests should get the exception""" + + class MockCredentialWithFailure(GetTokenMixin): + def __init__(self): + super().__init__() + self.request_count = 0 + + async def _acquire_token_silently(self, *scopes, **kwargs): + return None + + async def _request_token(self, *scopes, **kwargs): + self.request_count += 1 + await real_sleep(0.2) + raise ValueError("Authentication failed") + + credential = MockCredentialWithFailure() + + # Launch concurrent requests - all should hit the exception + tasks = [getattr(credential, get_token_method)(SCOPE) for _ in range(3)] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # All results should be ValueError exceptions + assert all(isinstance(r, ValueError) for r in results) + assert all(str(r) == "Authentication failed" for r in results) + assert credential.request_count == 3 + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_with_refresh_on_attribute(get_token_method): + """Test that tokens with refresh_on attribute are handled correctly in concurrent scenarios""" + + class MockCredentialWithRefreshOn(GetTokenMixin): + def __init__(self, cached_token): + super().__init__() + self.cached_token = cached_token + self.request_count = 0 + + async def _acquire_token_silently(self, *scopes, **kwargs): + return self.cached_token + + async def _request_token(self, *scopes, **kwargs): + self.request_count += 1 + request_count = self.request_count + await real_sleep(0.2) + # Create token with refresh_on set to future time + token = AccessTokenInfo( + f"token_{request_count}", + int(time.time() + 3600), + refresh_on=int(time.time() + 1800), # Refresh halfway through validity + ) + self.cached_token = token + return token + + # Create token with refresh_on in the past (needs immediate refresh) + old_token = AccessTokenInfo("old_token", int(time.time() + 3600), refresh_on=int(time.time() - 1)) + credential = MockCredentialWithRefreshOn(old_token) + + # Launch concurrent requests + tasks = [getattr(credential, get_token_method)(SCOPE) for _ in range(5)] + tokens = await asyncio.gather(*tasks) + + # All should get the new token + for token in tokens: + assert token.token == "token_1" + + # Only one refresh should have been made + assert credential.request_count == 1
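
Note for reviewers (illustration only, not part of the patch): below is a minimal sketch of the behavior the new per-key lock provides, modeled on the MockCredentialWithDelay helper in the tests above. DemoCredential is a hypothetical toy subclass and imports GetTokenMixin from the private aio._internal module purely for demonstration; with the lock in place, five concurrent get_token calls for the same scope should collapse into a single _request_token call.

    import asyncio
    import time

    from azure.core.credentials import AccessTokenInfo
    from azure.identity.aio._internal.get_token_mixin import GetTokenMixin


    class DemoCredential(GetTokenMixin):
        """Toy credential: caches one token and counts simulated network requests."""

        def __init__(self):
            super().__init__()
            self._cached = None
            self.requests = 0

        async def _acquire_token_silently(self, *scopes, **kwargs):
            return self._cached

        async def _request_token(self, *scopes, **kwargs):
            self.requests += 1
            await asyncio.sleep(0.2)  # simulate network latency so concurrent callers pile up
            self._cached = AccessTokenInfo("token", int(time.time()) + 3600)
            return self._cached


    async def main() -> None:
        credential = DemoCredential()
        # Five coroutines request the same scope at once; the per-key lock
        # serializes the refresh, so only one underlying request is expected.
        await asyncio.gather(*(credential.get_token("scope") for _ in range(5)))
        assert credential.requests == 1


    asyncio.run(main())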