Skip to content

Commit 09bd3e8

Browse files
backward compatibility with Python <=3.10
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
1 parent 1c7bef8 commit 09bd3e8

File tree

5 files changed

+113
-5
lines changed

5 files changed

+113
-5
lines changed

src/dispatch/asyncio.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import asyncio
2+
import functools
3+
import inspect
4+
import signal
5+
import threading
6+
7+
class Runner:
8+
"""Runner is a class similar to asyncio.Runner but that we use for backward
9+
compatibility with Python 3.10 and earlier.
10+
"""
11+
12+
def __init__(self):
13+
self._loop = asyncio.new_event_loop()
14+
self._interrupt_count = 0
15+
16+
def __enter__(self):
17+
return self
18+
19+
def __exit__(self, *args, **kwargs):
20+
self.close()
21+
22+
def close(self):
23+
try:
24+
loop = self._loop
25+
_cancel_all_tasks(loop)
26+
loop.run_until_complete(loop.shutdown_asyncgens())
27+
if hasattr(loop, 'shutdown_default_executor'): # Python 3.9+
28+
loop.run_until_complete(loop.shutdown_default_executor())
29+
finally:
30+
loop.close()
31+
32+
def get_loop(self):
33+
return self._loop
34+
35+
def run(self, coro):
36+
if not inspect.iscoroutine(coro):
37+
raise ValueError("a coroutine was expected, got {!r}".format(coro))
38+
39+
try:
40+
asyncio.get_running_loop()
41+
except RuntimeError:
42+
pass
43+
else:
44+
raise RuntimeError("Runner.run() cannot be called from a running event loop")
45+
46+
task = self._loop.create_task(coro)
47+
sigint_handler = None
48+
49+
if (threading.current_thread() is threading.main_thread()
50+
and signal.getsignal(signal.SIGINT) is signal.default_int_handler
51+
):
52+
sigint_handler = functools.partial(self._on_sigint, main_task=task)
53+
try:
54+
signal.signal(signal.SIGINT, sigint_handler)
55+
except ValueError:
56+
# `signal.signal` may throw if `threading.main_thread` does
57+
# not support signals (e.g. embedded interpreter with signals
58+
# not registered - see gh-91880)
59+
sigint_handler = None
60+
61+
self._interrupt_count = 0
62+
try:
63+
asyncio.set_event_loop(self._loop)
64+
return self._loop.run_until_complete(task)
65+
except asyncio.CancelledError:
66+
if self._interrupt_count > 0:
67+
uncancel = getattr(task, "uncancel", None)
68+
if uncancel is not None and uncancel() == 0:
69+
raise KeyboardInterrupt()
70+
raise # CancelledError
71+
finally:
72+
asyncio.set_event_loop(None)
73+
if (sigint_handler is not None
74+
and signal.getsignal(signal.SIGINT) is sigint_handler
75+
):
76+
signal.signal(signal.SIGINT, signal.default_int_handler)
77+
78+
def _on_sigint(self, signum, frame, main_task):
79+
self._interrupt_count += 1
80+
if self._interrupt_count == 1 and not main_task.done():
81+
main_task.cancel()
82+
# wakeup loop if it is blocked by select() with long timeout
83+
self._loop.call_soon_threadsafe(lambda: None)
84+
return
85+
raise KeyboardInterrupt()
86+
87+
def _cancel_all_tasks(loop):
88+
to_cancel = asyncio.all_tasks(loop)
89+
if not to_cancel:
90+
return
91+
92+
for task in to_cancel:
93+
task.cancel()
94+
95+
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
96+
97+
for task in to_cancel:
98+
if task.cancelled():
99+
continue
100+
if task.exception() is not None:
101+
loop.call_exception_handler({
102+
'message': 'unhandled exception during asyncio.run() shutdown',
103+
'exception': task.exception(),
104+
'task': task,
105+
})

src/dispatch/experimental/lambda_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def handler(event, context):
2626

2727
from awslambdaric.lambda_context import LambdaContext
2828

29+
from dispatch.asyncio import Runner
2930
from dispatch.function import Registry
3031
from dispatch.proto import Input
3132
from dispatch.sdk.v1 import function_pb2 as function_pb
@@ -93,7 +94,7 @@ def handle(
9394

9495
input = Input(req)
9596
try:
96-
with asyncio.Runner() as runner:
97+
with Runner() as runner:
9798
output = runner.run(func._primitive_call(input))
9899
except Exception:
99100
logger.error("function '%s' fatal error", req.function, exc_info=True)

src/dispatch/flask.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def read_root():
2323

2424
from flask import Flask, make_response, request
2525

26+
from dispatch.asyncio import Runner
2627
from dispatch.function import Registry
2728
from dispatch.http import FunctionServiceError, function_service_run
2829
from dispatch.signature import Ed25519PublicKey, parse_verification_key
@@ -90,7 +91,7 @@ def _handle_error(self, exc: FunctionServiceError):
9091
def _execute(self):
9192
data: bytes = request.get_data(cache=False)
9293

93-
with asyncio.Runner() as runner:
94+
with Runner() as runner:
9495
content = runner.run(
9596
function_service_run(
9697
request.url,

src/dispatch/http.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from http_message_signatures import InvalidSignature
1111

12+
from dispatch.asyncio import Runner
1213
from dispatch.function import Registry
1314
from dispatch.proto import Input
1415
from dispatch.sdk.v1 import function_pb2 as function_pb
@@ -121,7 +122,7 @@ def do_POST(self):
121122
url = self.requestline # TODO: need full URL
122123

123124
try:
124-
with asyncio.Runner() as runner:
125+
with Runner() as runner:
125126
content = runner.run(
126127
function_service_run(
127128
url,

tests/dispatch/test_scheduler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import asyncio
21
import unittest
32
from typing import Any, Callable, List, Optional, Set, Type
43

4+
from dispatch.asyncio import Runner
55
from dispatch.coroutine import AnyException, any, call, gather, race
66
from dispatch.experimental.durable import durable
77
from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall
@@ -55,7 +55,7 @@ async def raises_error():
5555

5656
class TestOneShotScheduler(unittest.TestCase):
5757
def setUp(self):
58-
self.runner = asyncio.Runner()
58+
self.runner = Runner()
5959

6060
def tearDown(self):
6161
self.runner.close()

0 commit comments

Comments
 (0)