11from inspect import isawaitable , isasyncgen
2+ from asyncio import ensure_future , wait , shield
23
3- from asyncio import ensure_future
44from aiohttp import WSMsgType
55from graphql .execution .executors .asyncio import AsyncioExecutor
66
@@ -23,6 +23,10 @@ async def receive(self):
2323 return msg .data
2424 elif msg .type == WSMsgType .ERROR :
2525 raise ConnectionClosedException ()
26+ elif msg .type == WSMsgType .CLOSING :
27+ raise ConnectionClosedException ()
28+ elif msg .type == WSMsgType .CLOSED :
29+ raise ConnectionClosedException ()
2630
2731 async def send (self , data ):
2832 if self .closed :
@@ -38,25 +42,42 @@ async def close(self, code):
3842
3943
4044class AiohttpSubscriptionServer (BaseSubscriptionServer ):
45+ def __init__ (self , schema , keep_alive = True , loop = None ):
46+ self .loop = loop
47+ super ().__init__ (schema , keep_alive )
4148
4249 def get_graphql_params (self , * args , ** kwargs ):
4350 params = super (AiohttpSubscriptionServer ,
4451 self ).get_graphql_params (* args , ** kwargs )
45- return dict (params , return_promise = True , executor = AsyncioExecutor ())
52+ return dict (params , return_promise = True , executor = AsyncioExecutor (loop = self . loop ))
4653
47- async def handle (self , ws , request_context = None ):
54+ async def _handle (self , ws , request_context = None ):
4855 connection_context = AiohttpConnectionContext (ws , request_context )
4956 await self .on_open (connection_context )
57+ pending_tasks = []
5058 while True :
5159 try :
5260 if connection_context .closed :
5361 raise ConnectionClosedException ()
5462 message = await connection_context .receive ()
5563 except ConnectionClosedException :
56- self .on_close (connection_context )
57- return
64+ break
65+ finally :
66+ pending_tasks = [t for t in pending_tasks if not t .done ()]
5867
59- ensure_future (self .on_message (connection_context , message ))
68+ task = ensure_future (
69+ self .on_message (connection_context , message ), loop = self .loop )
70+ pending_tasks .append (task )
71+
72+ self .on_close (connection_context )
73+ if pending_tasks :
74+ for task in pending_tasks :
75+ if not task .done ():
76+ task .cancel ()
77+ await wait (pending_tasks , loop = self .loop )
78+
79+ async def handle (self , ws , request_context = None ):
80+ await shield (self ._handle (ws , request_context ), loop = self .loop )
6081
6182 async def on_open (self , connection_context ):
6283 pass
0 commit comments