Skip to content

Commit c764476

Browse files
committed
[Identity] Add async token method concurrency control
Implements per-scope lock mechanism to prevent concurrent token requests for the same token request argument combination, reducing unnecessary network calls and improving performance in high-concurrency async scenarios. Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
1 parent 90a9226 commit c764476

File tree

4 files changed

+492
-15
lines changed

4 files changed

+492
-15
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
### Other Changes
1212

13+
- Add lock-based concurrency control to asynchronous token requests. This addresses race conditions where multiple coroutines could simultaneously request tokens for the same scopes, resulting in redundant authentication requests. ([#43889](https://github.com/Azure/azure-sdk-for-python/pull/43889))
14+
1315
## 1.26.0b1 (2025-11-07)
1416

1517
### Features Added

sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
# ------------------------------------
55
import abc
66
import logging
7+
import threading
78
import time
8-
from typing import Any, Optional
9+
from typing import Any, Optional, Dict, Type
10+
from weakref import WeakValueDictionary
911

1012
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
1113
from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
1214
from ..._internal import within_credential_chain
15+
from .utils import get_running_async_lock_class
1316

1417
_LOGGER = logging.getLogger(__name__)
1518

@@ -18,9 +21,39 @@ class GetTokenMixin(abc.ABC):
1821
def __init__(self, *args: Any, **kwargs: Any) -> None:
1922
self._last_request_time = 0
2023

24+
self._global_lock: Optional[Any] = None
25+
self._global_lock_init_lock = threading.Lock()
26+
self._active_locks: WeakValueDictionary[tuple, Any] = WeakValueDictionary()
27+
self._lock_class_type: Optional[Type] = None
28+
2129
# https://github.com/python/mypy/issues/5887
2230
super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore
2331

32+
@property
33+
def _lock_class(self) -> Type:
34+
if self._lock_class_type is None:
35+
self._lock_class_type = get_running_async_lock_class()
36+
return self._lock_class_type
37+
38+
async def _get_request_lock(self, lock_key: tuple) -> Any:
39+
# Initialize global lock if needed, using threading.Lock for thread-safe initialization
40+
if self._global_lock is None:
41+
with self._global_lock_init_lock:
42+
if self._global_lock is None:
43+
self._global_lock = self._lock_class()
44+
45+
lock = self._active_locks.get(lock_key)
46+
if lock is not None:
47+
return lock
48+
49+
async with self._global_lock:
50+
# Double-check in case another coroutine created it while we waited
51+
lock = self._active_locks.get(lock_key)
52+
if lock is None:
53+
lock = self._lock_class()
54+
self._active_locks[lock_key] = lock
55+
return lock
56+
2457
@abc.abstractmethod
2558
async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]:
2659
"""Attempt to acquire an access token from a cache or by redeeming a refresh token.
@@ -132,19 +165,29 @@ async def _get_token_base(
132165
token = await self._acquire_token_silently(
133166
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
134167
)
135-
if not token:
136-
self._last_request_time = int(time.time())
137-
token = await self._request_token(
138-
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
139-
)
140-
elif self._should_refresh(token):
141-
try:
142-
self._last_request_time = int(time.time())
143-
token = await self._request_token(
168+
if not token or self._should_refresh(token):
169+
# Get the lock specific to this scope combination
170+
lock_key = (tuple(sorted(scopes)), claims, tenant_id, enable_cae)
171+
lock = await self._get_request_lock(lock_key)
172+
173+
async with lock:
174+
# Double-check in case another coroutine refreshed the token while we waited for the lock
175+
current_token = await self._acquire_token_silently(
144176
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
145177
)
146-
except Exception: # pylint:disable=broad-except
147-
pass
178+
if current_token and not self._should_refresh(current_token):
179+
token = current_token
180+
else:
181+
try:
182+
token = await self._request_token(
183+
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
184+
)
185+
except Exception: # pylint:disable=broad-except
186+
self._last_request_time = int(time.time())
187+
# Only raise if we don't have a token to return
188+
if not token:
189+
raise
190+
148191
_LOGGER.log(
149192
logging.DEBUG if within_credential_chain.get() else logging.INFO,
150193
"%s.%s succeeded",
@@ -163,3 +206,19 @@ async def _get_token_base(
163206
exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
164207
)
165208
raise
209+
210+
def __getstate__(self) -> Dict[str, Any]:
211+
state = self.__dict__.copy()
212+
# Remove the non-picklable entries
213+
del state["_global_lock"]
214+
del state["_lock_class_type"]
215+
del state["_global_lock_init_lock"]
216+
del state["_active_locks"]
217+
return state
218+
219+
def __setstate__(self, state: Dict[str, Any]) -> None:
220+
self.__dict__.update(state)
221+
self._global_lock_init_lock = threading.Lock()
222+
self._active_locks = WeakValueDictionary()
223+
self._global_lock = None
224+
self._lock_class_type = None
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import sys
6+
from typing import Type
7+
8+
9+
def get_running_async_lock_class() -> Type:
10+
"""Get a lock class from the async library that the current context is running under.
11+
12+
:return: The running async library's Lock class.
13+
:rtype: Type[Lock]
14+
:raises RuntimeError: if the current context is not running under an async library.
15+
"""
16+
17+
try:
18+
import asyncio # pylint: disable=do-not-import-asyncio
19+
20+
# Check if we are running in an asyncio event loop.
21+
asyncio.get_running_loop()
22+
return asyncio.Lock
23+
except RuntimeError as err:
24+
# Otherwise, assume we are running in a trio event loop if it has already been imported.
25+
if "trio" in sys.modules:
26+
import trio # pylint: disable=networking-import-outside-azure-core-transport
27+
28+
return trio.Lock
29+
raise RuntimeError("An asyncio or trio event loop is required.") from err

0 commit comments

Comments
 (0)