22
33import asyncio
44import inspect
5+ import sys
56import warnings
67from collections .abc import Coroutine , Sequence
78from logging import getLogger
1718 overload ,
1819)
1920
20- from typing_extensions import TypeAlias
21+ from typing_extensions import Self , TypeAlias
2122
22- from reactpy .config import REACTPY_DEBUG_MODE , REACTPY_EFFECT_DEFAULT_STOP_TIMEOUT
23- from reactpy .core ._life_cycle_hook import EffectInfo , current_hook
23+ from reactpy .config import REACTPY_DEBUG_MODE
24+ from reactpy .core ._life_cycle_hook import StopEffect , current_hook
2425from reactpy .core .types import Context , Key , State , VdomDict
2526from reactpy .utils import Ref
2627
@@ -96,15 +97,14 @@ def dispatch(new: _Type | Callable[[_Type], _Type]) -> None:
9697
9798_EffectCleanFunc : TypeAlias = "Callable[[], None]"
9899_SyncEffectFunc : TypeAlias = "Callable[[], _EffectCleanFunc | None]"
99- _AsyncEffectFunc : TypeAlias = "Callable[[asyncio.Event ], Coroutine[None, None, None]]"
100+ _AsyncEffectFunc : TypeAlias = "Callable[[Effect ], Coroutine[None, None, None]]"
100101_EffectFunc : TypeAlias = "_SyncEffectFunc | _AsyncEffectFunc"
101102
102103
103104@overload
104105def use_effect (
105106 function : None = None ,
106107 dependencies : Sequence [Any ] | ellipsis | None = ...,
107- stop_timeout : float = ...,
108108) -> Callable [[_EffectFunc ], None ]:
109109 ...
110110
@@ -113,15 +113,13 @@ def use_effect(
113113def use_effect (
114114 function : _EffectFunc ,
115115 dependencies : Sequence [Any ] | ellipsis | None = ...,
116- stop_timeout : float = ...,
117116) -> None :
118117 ...
119118
120119
121120def use_effect (
122121 function : _EffectFunc | None = None ,
123122 dependencies : Sequence [Any ] | ellipsis | None = ...,
124- stop_timeout : float = REACTPY_EFFECT_DEFAULT_STOP_TIMEOUT .current ,
125123) -> Callable [[_EffectFunc ], None ] | None :
126124 """See the full :ref:`Use Effect` docs for details
127125
@@ -145,22 +143,21 @@ def use_effect(
145143 hook = current_hook ()
146144 dependencies = _try_to_infer_closure_values (function , dependencies )
147145 memoize = use_memo (dependencies = dependencies )
148- effect_info : Ref [EffectInfo | None ] = use_ref (None )
146+ effect_ref : Ref [Effect | None ] = use_ref (None )
149147
150148 def add_effect (function : _EffectFunc ) -> None :
151- effect = _cast_async_effect (function )
149+ effect_func = _cast_async_effect (function )
152150
153- async def create_effect_task () -> EffectInfo :
154- if effect_info .current is not None :
155- last_effect_info = effect_info .current
156- await last_effect_info .signal_stop (stop_timeout )
151+ async def start_effect () -> StopEffect :
152+ if effect_ref .current is not None :
153+ await effect_ref .current .stop ()
157154
158- stop = asyncio .Event ()
159- info = EffectInfo (asyncio .create_task (effect (stop )), stop )
160- effect_info .current = info
161- return info
155+ effect = effect_ref .current = Effect ()
156+ effect .task = asyncio .create_task (effect_func (effect ))
162157
163- return memoize (lambda : hook .add_effect (create_effect_task ))
158+ return effect .stop
159+
160+ return memoize (lambda : hook .add_effect (start_effect ))
164161
165162 if function is not None :
166163 add_effect (function )
@@ -169,47 +166,114 @@ async def create_effect_task() -> EffectInfo:
169166 return add_effect
170167
171168
169+ class Effect :
170+ """A context manager for running asynchronous effects."""
171+
172+ task : asyncio .Task [Any ]
173+ """The task that is running the effect."""
174+
175+ def __init__ (self ) -> None :
176+ self ._stop = asyncio .Event ()
177+ self ._started = False
178+ self ._cancel_count = 0
179+
180+ async def stop (self ) -> None :
181+ """Signal the effect to stop."""
182+ if self ._started :
183+ self ._cancel_task ()
184+ self ._stop .set ()
185+ try :
186+ await self .task
187+ except asyncio .CancelledError :
188+ pass
189+ except Exception :
190+ logger .exception ("Error while stopping effect" )
191+
192+ async def __aenter__ (self ) -> Self :
193+ self ._started = True
194+ self ._cancel_count = self .task .cancelling ()
195+ if self ._stop .is_set ():
196+ self ._cancel_task ()
197+ return self
198+
199+ if sys .version_info < (3 , 11 ): # nocov
200+ # Python<3.11 doesn't have Task.cancelling so we need to track it ourselves.
201+
202+ _3_11__aenter__ = __aenter__
203+
204+ async def __aenter__ (self ) -> Self :
205+ cancel_count = 0
206+ old_cancel = self .task .cancel
207+
208+ def new_cancel (* a , ** kw ) -> None :
209+ nonlocal cancel_count
210+ cancel_count += 1
211+ return old_cancel (* a , ** kw )
212+
213+ self .task .cancel = new_cancel
214+ self .task .cancelling = lambda : cancel_count
215+
216+ return await _3_11__aenter__ (self ) # noqa: F821
217+
218+ async def __aexit__ (self , exc_type : type [BaseException ], * exc : Any ) -> Any :
219+ if exc_type is not None and not issubclass (exc_type , asyncio .CancelledError ):
220+ # propagate non-cancellation exceptions
221+ return None
222+
223+ try :
224+ await self ._stop .wait ()
225+ except asyncio .CancelledError :
226+ if self .task .cancelling () > self ._cancel_count :
227+ # Task has been cancelled by something else - propagate it
228+ return None
229+
230+ return True
231+
232+ def _cancel_task (self ) -> None :
233+ self .task .cancel ()
234+ self ._cancel_count += 1
235+
236+
172237def _cast_async_effect (function : Callable [..., Any ]) -> _AsyncEffectFunc :
173238 if inspect .iscoroutinefunction (function ):
174239 if len (inspect .signature (function ).parameters ):
175240 return function
176241
177242 warnings .warn (
178- ' Async effect functions should accept a "stop" asyncio.Event as their '
243+ " Async effect functions should accept an Effect context manager as their "
179244 "first argument. This will be required in a future version of ReactPy." ,
180245 stacklevel = 3 ,
181246 )
182247
183- async def wrapper (stop : asyncio .Event ) -> None :
184- task = asyncio .create_task (function ())
185- await stop .wait ()
186- if not task .cancel ():
248+ async def wrapper (effect : Effect ) -> None :
249+ cleanup = None
250+ async with effect :
187251 try :
188- cleanup = await task
252+ cleanup = await function ()
189253 except Exception :
190254 logger .exception ("Error while applying effect" )
191- return
192- if cleanup is not None :
193- try :
194- cleanup ()
195- except Exception :
196- logger .exception ("Error while cleaning up effect" )
255+ if cleanup is not None :
256+ try :
257+ cleanup ()
258+ except Exception :
259+ logger .exception ("Error while cleaning up effect" )
197260
198261 return wrapper
199262 else :
200263
201- async def wrapper (stop : asyncio .Event ) -> None :
202- try :
203- cleanup = function ()
204- except Exception :
205- logger .exception ("Error while applying effect" )
206- return
207- await stop .wait ()
208- try :
209- if cleanup is not None :
264+ async def wrapper (effect : Effect ) -> None :
265+ cleanup = None
266+ async with effect :
267+ try :
268+ cleanup = function ()
269+ except Exception :
270+ logger .exception ("Error while applying effect" )
271+
272+ if cleanup is not None :
273+ try :
210274 cleanup ()
211- except Exception :
212- logger .exception ("Error while cleaning up effect" )
275+ except Exception :
276+ logger .exception ("Error while cleaning up effect" )
213277
214278 return wrapper
215279
0 commit comments