@@ -13,16 +13,52 @@ async def __aenter__(self) -> None: ...
1313 async def __aexit__ (self , exc_type , exc_val , exc_tb ) -> None : ...
1414
1515
16- # Try to use anyio.Lock, fall back to asyncio.Lock
17- # Note: anyio is required for proper trio support
16+ # Import both libraries if available
17+ import asyncio # noqa: WPS433
18+ from enum import Enum , auto
19+
20+ class AsyncContext (Enum ):
21+ """Enum representing different async context types."""
22+
23+ ASYNCIO = auto ()
24+ TRIO = auto ()
25+ UNKNOWN = auto ()
26+
27+ # Check for anyio and trio availability
1828try :
1929 import anyio # noqa: WPS433
30+ has_anyio = True
31+ try :
32+ import trio # noqa: WPS433
33+ has_trio = True
34+ except ImportError : # pragma: no cover
35+ has_trio = False
2036except ImportError : # pragma: no cover
21- import asyncio # noqa: WPS433
37+ has_anyio = False
38+ has_trio = False
2239
23- Lock : type [AsyncLock ] = asyncio .Lock
24- else :
25- Lock = cast (type [AsyncLock ], anyio .Lock )
40+
41+ def detect_async_context () -> AsyncContext :
42+ """Detect which async context we're currently running in.
43+
44+ Returns:
45+ AsyncContext: The current async context type
46+ """
47+ if not has_anyio : # pragma: no cover
48+ return AsyncContext .ASYNCIO
49+
50+ if has_trio :
51+ try :
52+ # Check if we're in a trio context
53+ # Will raise RuntimeError if not in trio context
54+ trio .lowlevel .current_task ()
55+ return AsyncContext .TRIO
56+ except (RuntimeError , AttributeError ):
57+ # Not in a trio context or trio API changed
58+ pass
59+
60+ # Default to asyncio
61+ return AsyncContext .ASYNCIO
2662
2763_ValueType = TypeVar ('_ValueType' )
2864_AwaitableT = TypeVar ('_AwaitableT' , bound = Awaitable )
@@ -78,9 +114,9 @@ class ReAwaitable:
78114
79115 def __init__ (self , coro : Awaitable [_ValueType ]) -> None :
80116 """We need just an awaitable to work with."""
81- self ._lock = Lock ()
82117 self ._coro = coro
83118 self ._cache : _ValueType | _Sentinel = _sentinel
119+ self ._lock = None # Will be created lazily based on the backend
84120
85121 def __await__ (self ) -> Generator [None , None , _ValueType ]:
86122 """
@@ -126,8 +162,22 @@ def __repr__(self) -> str:
126162 """
127163 return repr (self ._coro )
128164
165+ def _create_lock (self ) -> AsyncLock :
166+ """Create the appropriate lock based on the current async context."""
167+ context = detect_async_context ()
168+
169+ if context == AsyncContext .TRIO and has_anyio :
170+ return anyio .Lock ()
171+
172+ # For ASYNCIO or UNKNOWN contexts
173+ return asyncio .Lock ()
174+
129175 async def _awaitable (self ) -> _ValueType :
130176 """Caches the once awaited value forever."""
177+ # Create the lock if it doesn't exist
178+ if self ._lock is None :
179+ self ._lock = self ._create_lock ()
180+
131181 async with self ._lock :
132182 if self ._cache is _sentinel :
133183 self ._cache = await self ._coro
0 commit comments