Skip to content

Commit 1c2203d

Browse files
committed
[Identity] Add async token method concurrency control
Implements per-scope lock mechanism to prevent concurrent token requests for the same token request argument combination, reducing unnecessary network calls and improving performance in high-concurrency async scenarios. Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
1 parent 90a9226 commit 1c2203d

File tree

4 files changed

+484
-15
lines changed

4 files changed

+484
-15
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
### Other Changes
1212

13+
- Add lock-based concurrency control to asynchronous token requests. This addresses race conditions where multiple coroutines could simultaneously request tokens for the same scopes, resulting in redundant authentication requests. ([#43889](https://github.com/Azure/azure-sdk-for-python/pull/43889))
14+
1315
## 1.26.0b1 (2025-11-07)
1416

1517
### Features Added

sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import abc
66
import logging
77
import time
8-
from typing import Any, Optional
8+
from typing import Any, Optional, MutableMapping
9+
from weakref import WeakValueDictionary
910

1011
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
1112
from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
1213
from ..._internal import within_credential_chain
14+
from .utils import get_running_async_lock_class
1315

1416
_LOGGER = logging.getLogger(__name__)
1517

@@ -18,9 +20,35 @@ class GetTokenMixin(abc.ABC):
1820
def __init__(self, *args: Any, **kwargs: Any) -> None:
1921
self._last_request_time = 0
2022

23+
self._global_lock = None
24+
self._active_locks: MutableMapping = WeakValueDictionary()
25+
self.__lock_class = None
26+
2127
# https://github.com/python/mypy/issues/5887
2228
super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore
2329

30+
@property
31+
def _lock_class(self):
32+
if self.__lock_class is None:
33+
self.__lock_class = get_running_async_lock_class()
34+
return self.__lock_class
35+
36+
async def _get_request_lock(self, scope_key):
37+
if self._global_lock is None:
38+
self._global_lock = self._lock_class()
39+
40+
lock = self._active_locks.get(scope_key)
41+
if lock is not None:
42+
return lock
43+
44+
async with self._global_lock:
45+
# Double-check in case another coroutine created it while we waited
46+
lock = self._active_locks.get(scope_key)
47+
if lock is None:
48+
lock = self._lock_class()
49+
self._active_locks[scope_key] = lock
50+
return lock
51+
2452
@abc.abstractmethod
2553
async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]:
2654
"""Attempt to acquire an access token from a cache or by redeeming a refresh token.
@@ -132,19 +160,29 @@ async def _get_token_base(
132160
token = await self._acquire_token_silently(
133161
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
134162
)
135-
if not token:
136-
self._last_request_time = int(time.time())
137-
token = await self._request_token(
138-
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
139-
)
140-
elif self._should_refresh(token):
141-
try:
142-
self._last_request_time = int(time.time())
143-
token = await self._request_token(
163+
if not token or self._should_refresh(token):
164+
# Get the lock specific to this scope combination
165+
lock_key = (tuple(sorted(scopes)), claims, tenant_id, enable_cae)
166+
lock = await self._get_request_lock(lock_key)
167+
168+
async with lock:
169+
# Double-check in case another coroutine refreshed the token while we waited for the lock
170+
current_token = await self._acquire_token_silently(
144171
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
145172
)
146-
except Exception: # pylint:disable=broad-except
147-
pass
173+
if current_token and not self._should_refresh(current_token):
174+
token = current_token
175+
else:
176+
try:
177+
token = await self._request_token(
178+
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
179+
)
180+
except Exception: # pylint:disable=broad-except
181+
self._last_request_time = int(time.time())
182+
# Only raise if we don't have a token to return
183+
if not token:
184+
raise
185+
148186
_LOGGER.log(
149187
logging.DEBUG if within_credential_chain.get() else logging.INFO,
150188
"%s.%s succeeded",
@@ -163,3 +201,16 @@ async def _get_token_base(
163201
exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
164202
)
165203
raise
204+
205+
def __getstate__(self) -> dict:
206+
state = self.__dict__.copy()
207+
# Remove the non-picklable entries
208+
state["_global_lock"] = None
209+
state["_active_locks"] = {}
210+
state["__lock_class"] = None
211+
212+
return state
213+
214+
def __setstate__(self, state: dict) -> None:
215+
self.__dict__.update(state)
216+
self._active_locks = WeakValueDictionary()
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import sys
6+
from typing import Type
7+
8+
9+
def get_running_async_lock_class() -> Type:
10+
"""Get a lock class from the async library that the current context is running under.
11+
12+
:return: The running async library's Lock class.
13+
:rtype: Type[Lock]
14+
:raises RuntimeError: if the current context is not running under an async library.
15+
"""
16+
17+
try:
18+
import asyncio # pylint: disable=do-not-import-asyncio
19+
20+
# Check if we are running in an asyncio event loop.
21+
asyncio.get_running_loop()
22+
return asyncio.Lock
23+
except RuntimeError as err:
24+
# Otherwise, assume we are running in a trio event loop if it has already been imported.
25+
if "trio" in sys.modules:
26+
import trio # pylint: disable=networking-import-outside-azure-core-transport
27+
28+
return trio.Lock
29+
raise RuntimeError("An asyncio or trio event loop is required.") from err

0 commit comments

Comments
 (0)