Skip to content

Commit ad3713c

Browse files
committed
Add websockets library context and server classes
1 parent 5f5e85e commit ad3713c

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

graphql_ws/websockets_lib.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from inspect import isawaitable, isasyncgen
2+
3+
from asyncio import ensure_future
4+
from websockets import ConnectionClosed
5+
from graphql.execution.executors.asyncio import AsyncioExecutor
6+
7+
from .base import ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer
8+
from .observable_aiter import setup_observable_extension
9+
10+
from .constants import (
11+
GQL_CONNECTION_ACK,
12+
GQL_CONNECTION_ERROR,
13+
GQL_COMPLETE
14+
)
15+
16+
setup_observable_extension()
17+
18+
19+
class WsLibConnectionContext(BaseConnectionContext):
20+
async def receive(self):
21+
try:
22+
msg = await self.ws.recv()
23+
return msg
24+
except ConnectionClosed:
25+
raise ConnectionClosedException()
26+
27+
async def send(self, data):
28+
if self.closed:
29+
return
30+
await self.ws.send(data)
31+
32+
@property
33+
def closed(self):
34+
return self.ws.open is False
35+
36+
async def close(self, code):
37+
await self.ws.close(code)
38+
39+
40+
class WsLibSubscriptionServer(BaseSubscriptionServer):
41+
42+
def get_graphql_params(self, *args, **kwargs):
43+
params = super(WsLibSubscriptionServer,
44+
self).get_graphql_params(*args, **kwargs)
45+
return dict(params, return_promise=True, executor=AsyncioExecutor())
46+
47+
async def handle(self, ws, request_context=None):
48+
connection_context = WsLibConnectionContext(ws, request_context)
49+
await self.on_open(connection_context)
50+
while True:
51+
try:
52+
if connection_context.closed:
53+
raise ConnectionClosedException()
54+
message = await connection_context.receive()
55+
except ConnectionClosedException:
56+
self.on_close(connection_context)
57+
return
58+
59+
ensure_future(self.on_message(connection_context, message))
60+
61+
async def on_open(self, connection_context):
62+
pass
63+
64+
def on_close(self, connection_context):
65+
remove_operations = list(connection_context.operations.keys())
66+
for op_id in remove_operations:
67+
self.unsubscribe(connection_context, op_id)
68+
69+
async def on_connect(self, connection_context, payload):
70+
pass
71+
72+
async def on_connection_init(self, connection_context, op_id, payload):
73+
try:
74+
await self.on_connect(connection_context, payload)
75+
await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK)
76+
except Exception as e:
77+
await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR)
78+
await connection_context.close(1011)
79+
80+
async def on_start(self, connection_context, op_id, params):
81+
execution_result = self.execute(
82+
connection_context.request_context, params)
83+
84+
if isawaitable(execution_result):
85+
execution_result = await execution_result
86+
87+
if not hasattr(execution_result, '__aiter__'):
88+
await self.send_execution_result(connection_context, op_id, execution_result)
89+
else:
90+
iterator = await execution_result.__aiter__()
91+
connection_context.register_operation(op_id, iterator)
92+
async for single_result in iterator:
93+
if not connection_context.has_operation(op_id):
94+
break
95+
await self.send_execution_result(connection_context, op_id, single_result)
96+
await self.send_message(connection_context, op_id, GQL_COMPLETE)
97+
98+
async def on_stop(self, connection_context, op_id):
99+
self.unsubscribe(connection_context, op_id)

0 commit comments

Comments
 (0)