Skip to content

Commit 5f2cff5

Browse files
authored
Merge pull request #15 from ciscorn/fix-websockets
Fix handler for websockets lib not to ignore pending tasks
2 parents f7da106 + e624306 commit 5f2cff5

File tree

2 files changed

+28
-15
lines changed

2 files changed

+28
-15
lines changed

graphql_ws/aiohttp.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def get_graphql_params(self, *args, **kwargs):
5454
async def _handle(self, ws, request_context=None):
5555
connection_context = AiohttpConnectionContext(ws, request_context)
5656
await self.on_open(connection_context)
57-
pending_tasks = []
57+
pending = set()
5858
while True:
5959
try:
6060
if connection_context.closed:
@@ -63,18 +63,16 @@ async def _handle(self, ws, request_context=None):
6363
except ConnectionClosedException:
6464
break
6565
finally:
66-
pending_tasks = [t for t in pending_tasks if not t.done()]
66+
if pending:
67+
(_, pending) = await wait(pending, timeout=0, loop=self.loop)
6768

6869
task = ensure_future(
6970
self.on_message(connection_context, message), loop=self.loop)
70-
pending_tasks.append(task)
71+
pending.add(task)
7172

7273
self.on_close(connection_context)
73-
if pending_tasks:
74-
for task in pending_tasks:
75-
if not task.done():
76-
task.cancel()
77-
await wait(pending_tasks, loop=self.loop)
74+
for task in pending:
75+
task.cancel()
7876

7977
async def handle(self, ws, request_context=None):
8078
await shield(self._handle(ws, request_context), loop=self.loop)

graphql_ws/websockets_lib.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from inspect import isawaitable, isasyncgen
1+
from inspect import isawaitable
22

3-
from asyncio import ensure_future
3+
from asyncio import ensure_future, wait, shield
44
from websockets import ConnectionClosed
55
from graphql.execution.executors.asyncio import AsyncioExecutor
66

@@ -38,25 +38,40 @@ async def close(self, code):
3838

3939

4040
class WsLibSubscriptionServer(BaseSubscriptionServer):
41+
def __init__(self, schema, keep_alive=True, loop=None):
42+
self.loop = loop
43+
super().__init__(schema, keep_alive)
4144

4245
def get_graphql_params(self, *args, **kwargs):
4346
params = super(WsLibSubscriptionServer,
4447
self).get_graphql_params(*args, **kwargs)
45-
return dict(params, return_promise=True, executor=AsyncioExecutor())
48+
return dict(params, return_promise=True, executor=AsyncioExecutor(loop=self.loop))
4649

47-
async def handle(self, ws, request_context=None):
50+
async def _handle(self, ws, request_context):
4851
connection_context = WsLibConnectionContext(ws, request_context)
4952
await self.on_open(connection_context)
53+
pending = set()
5054
while True:
5155
try:
5256
if connection_context.closed:
5357
raise ConnectionClosedException()
5458
message = await connection_context.receive()
5559
except ConnectionClosedException:
56-
self.on_close(connection_context)
57-
return
60+
break
61+
finally:
62+
if pending:
63+
(_, pending) = await wait(pending, timeout=0, loop=self.loop)
64+
65+
task = ensure_future(
66+
self.on_message(connection_context, message), loop=self.loop)
67+
pending.add(task)
5868

59-
ensure_future(self.on_message(connection_context, message))
69+
self.on_close(connection_context)
70+
for task in pending:
71+
task.cancel()
72+
73+
async def handle(self, ws, request_context=None):
74+
await shield(self._handle(ws, request_context), loop=self.loop)
6075

6176
async def on_open(self, connection_context):
6277
pass

0 commit comments

Comments
 (0)