Skip to content

Commit aa96baa

Browse files
committed
refactor: side effect runners always run the side effect in the event loop provided to them regardless of the return value of the side effect being a coroutine or not, this is because even if the side effect is not a coroutine, it might still use async features internally
1 parent 86f07ae commit aa96baa

File tree

7 files changed

+105
-94
lines changed

7 files changed

+105
-94
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Upcoming
44

55
- refactor: provide correct signature for the autorun instance based on the function it decorates
6+
- refactor: side effect runners always run the side effect in the event loop provided to them regardless of the return value of the side effect being a coroutine or not, this is because even if the side effect is not a coroutine, it might still use async features internally
67

78
## Version 0.18.3
89

redux/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,10 @@ def __init__(
9393
tuple[EventHandler[Event], Event] | None
9494
]()
9595
self._workers = [
96-
SideEffectRunnerThread(task_queue=self._event_handlers_queue)
96+
SideEffectRunnerThread(
97+
task_queue=self._event_handlers_queue,
98+
create_task=self._create_task,
99+
)
97100
for _ in range(self.store_options.threads)
98101
]
99102
for worker in self._workers:

redux/side_effect_runner.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
import asyncio
66
import contextlib
7+
import inspect
78
import threading
89
import weakref
910
from asyncio import Handle, iscoroutine
1011
from collections.abc import Callable
11-
from inspect import signature
1212
from typing import TYPE_CHECKING, Any, Generic, cast
1313

14-
from redux.basic_types import Event, EventHandler
14+
from redux.basic_types import Event, EventHandler, TaskCreator
1515

1616
if TYPE_CHECKING:
1717
import queue
@@ -24,15 +24,14 @@ def __init__(
2424
self: SideEffectRunnerThread,
2525
*,
2626
task_queue: queue.Queue[tuple[EventHandler[Event], Event] | None],
27+
create_task: TaskCreator | None,
2728
) -> None:
2829
"""Initialize the side effect runner thread."""
2930
super().__init__()
3031
self.task_queue = task_queue
3132
self.loop = asyncio.get_event_loop()
3233
self._handles: set[Handle] = set()
33-
self.create_task = lambda coro: self._handles.add(
34-
self.loop.call_soon_threadsafe(self.loop.create_task, coro),
35-
)
34+
self.create_task = create_task
3635

3736
def run(self: SideEffectRunnerThread[Event]) -> None:
3837
"""Run the side effect runner thread."""
@@ -51,12 +50,27 @@ def run(self: SideEffectRunnerThread[Event]) -> None:
5150
event_handler = event_handler_
5251
parameters = 1
5352
with contextlib.suppress(Exception):
54-
parameters = len(signature(event_handler).parameters)
55-
if parameters == 1:
56-
result = cast(Callable[[Event], Any], event_handler)(event)
57-
else:
58-
result = cast(Callable[[], Any], event_handler)()
59-
if iscoroutine(result):
60-
self.create_task(result)
53+
parameters = len(inspect.signature(event_handler).parameters)
54+
55+
if self.create_task:
56+
57+
async def _(
58+
event_handler: EventHandler[Event],
59+
event: Event,
60+
parameters: int,
61+
) -> None:
62+
if parameters == 1:
63+
result = cast(Callable[[Event], Any], event_handler)(event)
64+
else:
65+
result = cast(Callable[[], Any], event_handler)()
66+
if iscoroutine(result):
67+
await result
68+
69+
self.create_task(_(event_handler, event, parameters))
70+
else: # noqa: PLR5501
71+
if parameters == 1:
72+
cast(Callable[[Event], Any], event_handler)(event)
73+
else:
74+
cast(Callable[[], Any], event_handler)()
6175
finally:
6276
self.task_queue.task_done()

redux_pytest/fixtures/event_loop.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,37 @@
88
import pytest
99

1010
if TYPE_CHECKING:
11-
from collections.abc import Coroutine
11+
from collections.abc import Callable, Coroutine
12+
13+
from redux.basic_types import TaskCreatorCallback
1214

1315

1416
class LoopThread(threading.Thread):
1517
def __init__(self: LoopThread) -> None:
1618
super().__init__()
1719
self.loop = asyncio.new_event_loop()
18-
asyncio.set_event_loop(self.loop)
1920

2021
def run(self: LoopThread) -> None:
2122
self.loop.run_forever()
2223

2324
def stop(self: LoopThread) -> None:
24-
asyncio.set_event_loop(None)
2525
self.loop.call_soon_threadsafe(self.loop.stop)
2626

27-
def create_task(self: LoopThread, coro: Coroutine) -> None:
28-
self.loop.call_soon_threadsafe(self.loop.create_task, coro)
27+
def create_task(
28+
self: LoopThread,
29+
coro: Coroutine,
30+
*,
31+
callback: TaskCreatorCallback | None = None,
32+
) -> None:
33+
def _(
34+
coro: Coroutine,
35+
callback: Callable[[asyncio.Task], None] | None = None,
36+
) -> None:
37+
task = self.loop.create_task(coro)
38+
if callback:
39+
task.add_done_callback(callback)
40+
41+
self.loop.call_soon_threadsafe(_, coro, callback)
2942

3043

3144
@pytest.fixture

0 commit comments

Comments
 (0)