11from __future__ import annotations
22
3- import asyncio
43import inspect
54import sys
65import warnings
7- from collections .abc import Coroutine , Sequence
6+ from asyncio import CancelledError , Event , create_task
7+ from collections .abc import Awaitable , Coroutine , Sequence
88from logging import getLogger
99from types import FunctionType
1010from typing import (
@@ -95,9 +95,10 @@ def dispatch(new: _Type | Callable[[_Type], _Type]) -> None:
9595 self .dispatch = dispatch
9696
9797
98+ _Coro = Coroutine [None , None , _Type ]
9899_EffectCleanFunc : TypeAlias = "Callable[[], None]"
99100_SyncEffectFunc : TypeAlias = "Callable[[], _EffectCleanFunc | None]"
100- _AsyncEffectFunc : TypeAlias = "Callable[[Effect], Coroutine[None, None, None]]"
101+ _AsyncEffectFunc : TypeAlias = "Callable[[Effect], _Coro[Awaitable[Any] | None]]"
101102_EffectFunc : TypeAlias = "_SyncEffectFunc | _AsyncEffectFunc"
102103
103104
@@ -152,8 +153,7 @@ async def start_effect() -> StopEffect:
152153 if effect_ref .current is not None :
153154 await effect_ref .current .stop ()
154155
155- effect = effect_ref .current = Effect ()
156- effect .task = asyncio .create_task (effect_func (effect ))
156+ effect = effect_ref .current = Effect (effect_func )
157157 await effect .started ()
158158
159159 return effect .stop
@@ -170,26 +170,37 @@ async def start_effect() -> StopEffect:
170170class Effect :
171171 """A context manager for running asynchronous effects."""
172172
173- task : asyncio .Task [Any ]
174- """The task that is running the effect."""
175-
176- def __init__ (self ) -> None :
177- self ._stop = asyncio .Event ()
178- self ._started = asyncio .Event ()
173+ def __init__ (self , effect_func : _AsyncEffectFunc ) -> None :
174+ self .task = create_task (effect_func (self ))
175+ self ._stop = Event ()
176+ self ._started = Event ()
177+ self ._stopped = Event ()
179178 self ._cancel_count = 0
180179
181180 async def stop (self ) -> None :
182181 """Signal the effect to stop."""
182+ if self ._stop .is_set ():
183+ await self ._stopped .wait ()
184+ return None
185+
183186 if self ._started .is_set ():
184187 self ._cancel_task ()
185188 self ._stop .set ()
186189 try :
187- await self .task
188- except asyncio . CancelledError :
190+ cleanup = await self .task
191+ except CancelledError :
189192 pass
190193 except Exception :
191194 logger .exception ("Error while stopping effect" )
192195
196+ if cleanup is not None :
197+ try :
198+ await cleanup
199+ except Exception :
200+ logger .exception ("Error while cleaning up effect" )
201+
202+ self ._stopped .set ()
203+
193204 async def started (self ) -> None :
194205 """Wait for the effect to start."""
195206 await self ._started .wait ()
@@ -205,6 +216,7 @@ async def __aenter__(self) -> Self:
205216
206217 if sys .version_info < (3 , 11 ): # nocov
207218 # Python<3.11 doesn't have Task.cancelling so we need to track it ourselves.
219+ # Task.uncancel is a no-op since there's no way to backport the behavior.
208220
209221 async def __aenter__ (self ) -> Self :
210222 cancel_count = 0
@@ -217,20 +229,22 @@ def new_cancel(*a, **kw) -> None:
217229
218230 self .task .cancel = new_cancel
219231 self .task .cancelling = lambda : cancel_count
232+ self .task .uncancel = lambda : None
220233
221234 return await self ._3_11__aenter__ ()
222235
223236 async def __aexit__ (self , exc_type : type [BaseException ], * exc : Any ) -> Any :
224- if exc_type is not None and not issubclass (exc_type , asyncio . CancelledError ):
237+ if exc_type is not None and not issubclass (exc_type , CancelledError ):
225238 # propagate non-cancellation exceptions
226239 return None
227240
228241 try :
229242 await self ._stop .wait ()
230- except asyncio . CancelledError :
243+ except CancelledError :
231244 if self .task .cancelling () > self ._cancel_count :
232245 # Task has been cancelled by something else - propagate it
233246 return None
247+ self .task .uncancel ()
234248
235249 return True
236250
0 commit comments