|
1 | | -from inspect import isawaitable, isasyncgen |
| 1 | +from inspect import isawaitable |
2 | 2 |
|
3 | | -from asyncio import ensure_future |
| 3 | +from asyncio import ensure_future, wait, shield |
4 | 4 | from websockets import ConnectionClosed |
5 | 5 | from graphql.execution.executors.asyncio import AsyncioExecutor |
6 | 6 |
|
@@ -38,25 +38,40 @@ async def close(self, code): |
38 | 38 |
|
39 | 39 |
|
40 | 40 | class WsLibSubscriptionServer(BaseSubscriptionServer): |
| 41 | + def __init__(self, schema, keep_alive=True, loop=None): |
| 42 | + self.loop = loop |
| 43 | + super().__init__(schema, keep_alive) |
41 | 44 |
|
42 | 45 | def get_graphql_params(self, *args, **kwargs): |
43 | 46 | params = super(WsLibSubscriptionServer, |
44 | 47 | self).get_graphql_params(*args, **kwargs) |
45 | | - return dict(params, return_promise=True, executor=AsyncioExecutor()) |
| 48 | + return dict(params, return_promise=True, executor=AsyncioExecutor(loop=self.loop)) |
46 | 49 |
|
47 | | - async def handle(self, ws, request_context=None): |
| 50 | + async def _handle(self, ws, request_context): |
48 | 51 | connection_context = WsLibConnectionContext(ws, request_context) |
49 | 52 | await self.on_open(connection_context) |
| 53 | + pending = set() |
50 | 54 | while True: |
51 | 55 | try: |
52 | 56 | if connection_context.closed: |
53 | 57 | raise ConnectionClosedException() |
54 | 58 | message = await connection_context.receive() |
55 | 59 | except ConnectionClosedException: |
56 | | - self.on_close(connection_context) |
57 | | - return |
| 60 | + break |
| 61 | + finally: |
| 62 | + if pending: |
| 63 | + (_, pending) = await wait(pending, timeout=0, loop=self.loop) |
| 64 | + |
| 65 | + task = ensure_future( |
| 66 | + self.on_message(connection_context, message), loop=self.loop) |
| 67 | + pending.add(task) |
58 | 68 |
|
59 | | - ensure_future(self.on_message(connection_context, message)) |
| 69 | + self.on_close(connection_context) |
| 70 | + for task in pending: |
| 71 | + task.cancel() |
| 72 | + |
| 73 | + async def handle(self, ws, request_context=None): |
| 74 | + await shield(self._handle(ws, request_context), loop=self.loop) |
60 | 75 |
|
61 | 76 | async def on_open(self, connection_context): |
62 | 77 | pass |
|
0 commit comments