22
33from __future__ import annotations
44
5- import dataclasses
65import inspect
76import queue
8- import threading
97import weakref
10- from asyncio import AbstractEventLoop , get_event_loop , iscoroutinefunction
8+ from asyncio import get_event_loop , iscoroutine
119from collections import defaultdict
1210from inspect import signature
1311from threading import Lock
14- from types import NoneType
1512from typing import Any , Callable , Coroutine , Generic , cast
1613
17- from immutable import Immutable , is_immutable
18-
1914from redux .autorun import Autorun
2015from redux .basic_types import (
2116 Action ,
3934 SelectorOutput ,
4035 SnapshotAtom ,
4136 State ,
37+ TaskCreator ,
38+ TaskCreatorCallback ,
4239 is_complete_reducer_result ,
4340 is_state_reducer_result ,
4441)
42+ from redux .serialization_mixin import SerializationMixin
43+ from redux .side_effect_runner import SideEffectRunnerThread
4544
4645
47- class _SideEffectRunnerThread (threading .Thread , Generic [Event ]):
48- def __init__ (
49- self : _SideEffectRunnerThread [Event ],
50- * ,
51- task_queue : queue .Queue [tuple [EventHandler [Event ], Event ] | None ],
52- async_loop : AbstractEventLoop ,
53- ) -> None :
54- super ().__init__ ()
55- self .task_queue = task_queue
56- self .async_loop = async_loop
57-
58- def create_task (self : _SideEffectRunnerThread [Event ], coro : Coroutine ) -> None :
59- self .async_loop .call_soon_threadsafe (lambda : self .async_loop .create_task (coro ))
60-
61- def run (self : _SideEffectRunnerThread [Event ]) -> None :
62- while True :
63- task = self .task_queue .get ()
64- if task is None :
65- self .task_queue .task_done ()
66- break
67-
68- try :
69- event_handler , event = task
70- if len (signature (event_handler ).parameters ) == 1 :
71- result = cast (Callable [[Event ], Any ], event_handler )(event )
72- else :
73- result = cast (Callable [[], Any ], event_handler )()
74- if iscoroutinefunction (event_handler ):
75- self .create_task (result )
76- finally :
77- self .task_queue .task_done ()
46+ def _default_task_creator (
47+ coro : Coroutine ,
48+ callback : TaskCreatorCallback | None = None ,
49+ ) -> None :
50+ result = get_event_loop ().create_task (coro )
51+ if callback :
52+ callback (result )
7853
7954
80- class Store (Generic [State , Action , Event ]):
55+ class Store (Generic [State , Action , Event ], SerializationMixin ):
8156 """Redux store for managing state and side effects."""
8257
8358 def __init__ (
@@ -88,7 +63,9 @@ def __init__(
8863 """Create a new store."""
8964 self .store_options = options or CreateStoreOptions ()
9065 self .reducer = reducer
91- self ._async_loop = self .store_options .async_loop or get_event_loop ()
66+ self ._create_task : TaskCreator = (
67+ self .store_options .task_creator or _default_task_creator
68+ )
9269
9370 self ._state : State | None = None
9471 self ._listeners : set [
@@ -110,14 +87,14 @@ def __init__(
11087 self ._event_handlers_queue = queue .Queue [
11188 tuple [EventHandler [Event ], Event ] | None
11289 ]()
113- workers = [
114- _SideEffectRunnerThread (
90+ self . _workers = [
91+ SideEffectRunnerThread (
11592 task_queue = self ._event_handlers_queue ,
116- async_loop = self ._async_loop ,
93+ task_creator = self ._create_task ,
11794 )
11895 for _ in range (self .store_options .threads )
11996 ]
120- for worker in workers :
97+ for worker in self . _workers :
12198 worker .start ()
12299
123100 self ._is_running = Lock ()
@@ -158,8 +135,8 @@ def _run_actions(self: Store[State, Action, Event]) -> None:
158135 else :
159136 listener = listener_
160137 result = listener (self ._state )
161- if iscoroutinefunction ( listener ):
162- self ._async_loop . create_task (result )
138+ if iscoroutine ( result ):
139+ self ._create_task (result )
163140
164141 def _run_event_handlers (self : Store [State , Action , Event ]) -> None :
165142 event = self ._events .pop (0 )
@@ -175,10 +152,13 @@ def _run_event_handlers(self: Store[State, Action, Event]) -> None:
175152 event_handler = event_handler_
176153 if not options .immediate_run :
177154 self ._event_handlers_queue .put ((event_handler , event ))
178- elif len (signature (event_handler ).parameters ) == 1 :
179- cast (Callable [[Event ], Any ], event_handler )(event )
180155 else :
181- cast (Callable [[], Any ], event_handler )()
156+ if len (signature (event_handler ).parameters ) == 1 :
157+ result = cast (Callable [[Event ], Any ], event_handler )(event )
158+ else :
159+ result = cast (Callable [[], Any ], event_handler )()
160+ if iscoroutine (result ):
161+ self ._create_task (result )
182162
183163 def run (self : Store [State , Action , Event ]) -> None :
184164 """Run the store."""
@@ -189,6 +169,12 @@ def run(self: Store[State, Action, Event]) -> None:
189169
190170 if len (self ._events ) > 0 :
191171 self ._run_event_handlers ()
172+ if not any (i .is_alive () for i in self ._workers ):
173+ for worker in self ._workers :
174+ worker .join ()
175+ self ._workers .clear ()
176+ self ._listeners .clear ()
177+ self ._event_handlers .clear ()
192178
193179 def dispatch (
194180 self : Store [State , Action , Event ],
@@ -258,15 +244,15 @@ def subscribe_event(
258244 self ._event_handlers [cast (type [Event ], event_type )].add (
259245 (handler_ref , subscription_options ),
260246 )
261- return lambda : self ._event_handlers [cast (type [Event ], event_type )].discard (
262- (handler_ref , subscription_options ),
263- )
264247
265- def _handle_finish_event (
266- self : Store [State , Action , Event ],
267- finish_event : Event ,
268- ) -> None :
269- _ = finish_event
248+ def unsubscribe () -> None :
249+ self ._event_handlers [cast (type [Event ], event_type )].discard (
250+ (handler_ref , subscription_options ),
251+ )
252+
253+ return unsubscribe
254+
255+ def _handle_finish_event (self : Store [State , Action , Event ]) -> None :
270256 for _ in range (self .store_options .threads ):
271257 self ._event_handlers_queue .put (None )
272258
@@ -301,28 +287,3 @@ def decorator(
301287 def snapshot (self : Store [State , Action , Event ]) -> SnapshotAtom :
302288 """Return a snapshot of the current state of the store."""
303289 return self .serialize_value (self ._state )
304-
305- @classmethod
306- def serialize_value (cls : type [Store ], obj : object | type ) -> SnapshotAtom :
307- """Serialize a value to a snapshot atom."""
308- if isinstance (obj , (int , float , str , bool , NoneType )):
309- return obj
310- if callable (obj ):
311- return cls .serialize_value (obj ())
312- if isinstance (obj , (list , tuple )):
313- return [cls .serialize_value (i ) for i in obj ]
314- if is_immutable (obj ):
315- return cls ._serialize_dataclass_to_dict (obj )
316- msg = f'Unable to serialize object with type `{ type (obj )} `.'
317- raise TypeError (msg )
318-
319- @classmethod
320- def _serialize_dataclass_to_dict (
321- cls : type [Store ],
322- obj : Immutable ,
323- ) -> dict [str , Any ]:
324- result = {}
325- for field in dataclasses .fields (obj ):
326- value = cls .serialize_value (getattr (obj , field .name ))
327- result [field .name ] = value
328- return result
0 commit comments