Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

### Other Changes

- Add lock-based concurrency control to asynchronous token requests. This addresses race conditions where multiple coroutines could simultaneously request tokens for the same scopes, resulting in redundant authentication requests. ([#43889](https://github.com/Azure/azure-sdk-for-python/pull/43889))

## 1.26.0b1 (2025-11-07)

### Features Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
# ------------------------------------
import abc
import logging
import threading
import time
from typing import Any, Optional
from typing import Any, Optional, Dict, Type
from weakref import WeakValueDictionary

from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
from ..._internal import within_credential_chain
from .utils import get_running_async_lock_class

_LOGGER = logging.getLogger(__name__)

Expand All @@ -18,9 +21,39 @@ class GetTokenMixin(abc.ABC):
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Set up request-throttle and per-request locking state, then continue the MRO chain.

    All positional and keyword arguments are forwarded unchanged to the next
    ``__init__`` in the method resolution order.
    """
    # Lock bookkeeping. The async lock class and the coordinating lock both start
    # unset and are filled in later, on first use.
    self._lock_class_type: Optional[Type] = None
    self._global_lock: Optional[Any] = None
    # Plain thread lock guarding the one-time creation of ``_global_lock``.
    self._global_lock_init_lock = threading.Lock()
    # Locks keyed by request parameters; values are held weakly, so an entry
    # disappears once nothing else references its lock.
    self._active_locks: WeakValueDictionary[tuple, Any] = WeakValueDictionary()

    # Timestamp (seconds) used by the refresh/retry logic; 0 means "never".
    self._last_request_time = 0

    # https://github.com/python/mypy/issues/5887
    super(GetTokenMixin, self).__init__(*args, **kwargs)  # type: ignore

@property
def _lock_class(self) -> Type:
    """Lock class of the running async library, resolved once and then cached.

    The first access calls ``get_running_async_lock_class()`` and stores the
    result on the instance; later accesses return the stored class.
    """
    resolved = self._lock_class_type
    if resolved is None:
        resolved = self._lock_class_type = get_running_async_lock_class()
    return resolved

async def _get_request_lock(self, lock_key: tuple) -> Any:
    """Return the lock registered for ``lock_key``, creating it on first use.

    :param tuple lock_key: Hashable key identifying one combination of request parameters.
    :return: A lock instance of the running async library's lock class. The same
        object is returned for the same key for as long as it stays registered.
    :rtype: Any
    """
    # One-time, lazy creation of the coordinating lock. A threading.Lock guards it
    # so that concurrent threads cannot each install their own instance.
    if self._global_lock is None:
        with self._global_lock_init_lock:
            if self._global_lock is None:
                self._global_lock = self._lock_class()

    # Fast path: a lock for this key is already registered.
    existing = self._active_locks.get(lock_key)
    if existing is not None:
        return existing

    async with self._global_lock:
        # Look again: another coroutine may have registered the lock while we
        # were waiting to enter this section.
        request_lock = self._active_locks.get(lock_key)
        if request_lock is not None:
            return request_lock

        request_lock = self._lock_class()
        # _active_locks holds values weakly; the caller's reference keeps the lock alive.
        self._active_locks[lock_key] = request_lock
        return request_lock

@abc.abstractmethod
async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]:
"""Attempt to acquire an access token from a cache or by redeeming a refresh token.
Expand Down Expand Up @@ -132,19 +165,29 @@ async def _get_token_base(
token = await self._acquire_token_silently(
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
)
if not token:
self._last_request_time = int(time.time())
token = await self._request_token(
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
)
elif self._should_refresh(token):
try:
self._last_request_time = int(time.time())
token = await self._request_token(
if not token or self._should_refresh(token):
# Get the lock specific to this scope combination
lock_key = (tuple(sorted(scopes)), claims, tenant_id, enable_cae)
lock = await self._get_request_lock(lock_key)

async with lock:
# Double-check in case another coroutine refreshed the token while we waited for the lock
current_token = await self._acquire_token_silently(
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
)
except Exception: # pylint:disable=broad-except
pass
if current_token and not self._should_refresh(current_token):
token = current_token
else:
try:
token = await self._request_token(
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
)
except Exception: # pylint:disable=broad-except
self._last_request_time = int(time.time())
Comment on lines +181 to +186
Copy link

Copilot AI Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _last_request_time update should occur before calling _request_token (on success path) as well, not only in the exception handler. Currently, when _request_token succeeds, _last_request_time is never updated, which could bypass the retry delay check in _should_refresh. This inconsistency means the retry delay only applies after failures, not after successful token requests.

Copilot uses AI. Check for mistakes.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to modify this logic to only set the _last_request_time throttle when a _request_token call fails.

This still prevents rapid retries on failures, without allowing a successful request for one scope to block a necessary refresh for another (which would be the case if we set this on success as well).

# Only raise if we don't have a token to return
if not token:
raise

_LOGGER.log(
logging.DEBUG if within_credential_chain.get() else logging.INFO,
"%s.%s succeeded",
Expand All @@ -163,3 +206,19 @@ async def _get_token_base(
exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
)
raise

def __getstate__(self) -> Dict[str, Any]:
    """Return the instance state to pickle, without the lock-related attributes.

    Works on a shallow copy of ``__dict__`` so the live instance keeps its locks.

    :return: The instance attributes minus the non-picklable lock entries.
    :rtype: dict[str, Any]
    :raises KeyError: if one of the lock attributes is missing from the instance.
    """
    picklable = dict(self.__dict__)
    # These entries are not picklable; __setstate__ recreates them.
    for transient in ("_global_lock", "_lock_class_type", "_global_lock_init_lock", "_active_locks"):
        del picklable[transient]
    return picklable

def __setstate__(self, state: Dict[str, Any]) -> None:
    """Restore pickled attributes and recreate the lock-related state from scratch.

    :param state: Attribute mapping to restore onto the instance.
    :type state: dict[str, Any]
    """
    self.__dict__.update(state)

    # Lock state is never taken from ``state``: the async lock class and the
    # coordinating lock go back to "unset", and the registry and its guard are
    # replaced with fresh, empty objects.
    self._lock_class_type = None
    self._global_lock = None
    self._active_locks = WeakValueDictionary()
    self._global_lock_init_lock = threading.Lock()
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import sys
from typing import Type


def get_running_async_lock_class() -> Type:
    """Get a lock class from the async library that the current context is running under.

    The class itself is returned rather than a lock instance, so a caller that
    needs several locks can probe the event loop once and instantiate the class
    as many times as required.

    :return: ``asyncio.Lock`` when an asyncio event loop is running; otherwise
        ``trio.Lock`` when trio has already been imported.
    :rtype: Type[Lock]
    :raises RuntimeError: if no asyncio event loop is running and trio has not been imported.
    """
    import asyncio  # pylint: disable=do-not-import-asyncio

    try:
        # Raises RuntimeError when called outside a running asyncio event loop.
        asyncio.get_running_loop()
    except RuntimeError as err:
        # Not under asyncio: assume trio is driving the context, but only if the
        # application has already imported it.
        if "trio" not in sys.modules:
            raise RuntimeError("An asyncio or trio event loop is required.") from err

        import trio  # pylint: disable=networking-import-outside-azure-core-transport

        return trio.Lock
    return asyncio.Lock
Loading