11from collections .abc import Awaitable , Callable , Generator
22from functools import wraps
3- from typing import NewType , ParamSpec , Protocol , TypeVar , cast , final
4-
3+ from typing import Literal , NewType , ParamSpec , Protocol , TypeVar , cast , final
4+ # Always import asyncio
5+ import asyncio
56
67class AsyncLock (Protocol ):
78 """A protocol for an asynchronous lock."""
@@ -13,26 +14,17 @@ async def __aenter__(self) -> None: ...
1314 async def __aexit__ (self , exc_type , exc_val , exc_tb ) -> None : ...
1415
1516
16- # Import both libraries if available
17- import asyncio # noqa: WPS433
18- from enum import Enum , auto
19-
20-
21- class AsyncContext (Enum ):
22- """Enum representing different async context types."""
23-
24- ASYNCIO = auto ()
25- TRIO = auto ()
26- UNKNOWN = auto ()
17+ # Define context types as literals
18+ AsyncContext = Literal ["asyncio" , "trio" , "unknown" ]
2719
2820
2921# Check for anyio and trio availability
3022try :
31- import anyio # noqa: WPS433
23+ import anyio # pragma: no qa
3224
3325 has_anyio = True
3426 try :
35- import trio # noqa: WPS433
27+ import trio # pragma: no qa
3628
3729 has_trio = True
3830 except ImportError : # pragma: no cover
@@ -42,27 +34,40 @@ class AsyncContext(Enum):
4234 has_trio = False
4335
4436
37+ def _is_in_trio_context () -> bool :
38+ """Check if we're in a trio context.
39+
40+ Returns:
41+ bool: True if we're in a trio context
42+ """
43+ if not has_trio :
44+ return False
45+
46+ # Import trio here since we already checked it's available
47+ import trio
48+
49+ try :
50+ # Will raise RuntimeError if not in trio context
51+ trio .lowlevel .current_task ()
52+ except (RuntimeError , AttributeError ):
53+ return False
54+ return True
55+
56+
4557def detect_async_context () -> AsyncContext :
4658 """Detect which async context we're currently running in.
4759
4860 Returns:
4961 AsyncContext: The current async context type
5062 """
5163 if not has_anyio : # pragma: no cover
52- return AsyncContext .ASYNCIO
53-
54- if has_trio :
55- try :
56- # Check if we're in a trio context
57- # Will raise RuntimeError if not in trio context
58- trio .lowlevel .current_task ()
59- return AsyncContext .TRIO
60- except (RuntimeError , AttributeError ):
61- # Not in a trio context or trio API changed
62- pass
64+ return "asyncio"
65+
66+ if _is_in_trio_context ():
67+ return "trio"
6368
6469 # Default to asyncio
65- return AsyncContext . ASYNCIO
70+ return "asyncio"
6671
6772
6873_ValueType = TypeVar ('_ValueType' )
@@ -121,7 +126,7 @@ def __init__(self, coro: Awaitable[_ValueType]) -> None:
121126 """We need just an awaitable to work with."""
122127 self ._coro = coro
123128 self ._cache : _ValueType | _Sentinel = _sentinel
124- self ._lock = None # Will be created lazily based on the backend
129+ self ._lock : AsyncLock | None = None # Will be created lazily based on the backend
125130
126131 def __await__ (self ) -> Generator [None , None , _ValueType ]:
127132 """
@@ -171,10 +176,11 @@ def _create_lock(self) -> AsyncLock:
171176 """Create the appropriate lock based on the current async context."""
172177 context = detect_async_context ()
173178
174- if context == AsyncContext .TRIO and has_anyio :
179+ if context == "trio" and has_anyio :
180+ import anyio
175181 return anyio .Lock ()
176182
177- # For ASYNCIO or UNKNOWN contexts
183+ # For asyncio or unknown contexts
178184 return asyncio .Lock ()
179185
180186 async def _awaitable (self ) -> _ValueType :
@@ -222,4 +228,4 @@ def decorator(
222228 ) -> _AwaitableT :
223229 return ReAwaitable (coro (* args , ** kwargs )) # type: ignore[return-value]
224230
225- return decorator
231+ return decorator
0 commit comments