44from graphql import format_error , graphql
55
66from .constants import (
7+ GQL_COMPLETE ,
78 GQL_CONNECTION_ERROR ,
89 GQL_CONNECTION_INIT ,
910 GQL_CONNECTION_TERMINATE ,
1011 GQL_DATA ,
1112 GQL_ERROR ,
13+ GQL_NEXT ,
1214 GQL_START ,
1315 GQL_STOP ,
16+ GQL_SUBSCRIBE ,
17+ TRANSPORT_WS_PROTOCOL ,
1418)
1519
1620
@@ -23,6 +27,9 @@ def __init__(self, ws, request_context=None):
2327 self .ws = ws
2428 self .operations = {}
2529 self .request_context = request_context
30+ self .transport_ws_protocol = request_context and TRANSPORT_WS_PROTOCOL in (
31+ request_context .get ("subprotocols" ) or []
32+ )
2633
2734 def has_operation (self , op_id ):
2835 return op_id in self .operations
@@ -41,7 +48,7 @@ def remove_operation(self, op_id):
4148
4249 def unsubscribe (self , op_id ):
4350 async_iterator = self .remove_operation (op_id )
44- if hasattr (async_iterator , ' dispose' ):
51+ if hasattr (async_iterator , " dispose" ):
4552 async_iterator .dispose ()
4653 return async_iterator
4754
@@ -84,12 +91,16 @@ def process_message(self, connection_context, parsed_message):
8491 elif op_type == GQL_CONNECTION_TERMINATE :
8592 return self .on_connection_terminate (connection_context , op_id )
8693
87- elif op_type == GQL_START :
94+ elif op_type == (
95+ GQL_SUBSCRIBE if connection_context .transport_ws_protocol else GQL_START
96+ ):
8897 assert isinstance (payload , dict ), "The payload must be a dict"
8998 params = self .get_graphql_params (connection_context , payload )
9099 return self .on_start (connection_context , op_id , params )
91100
92- elif op_type == GQL_STOP :
101+ elif op_type == (
102+ GQL_COMPLETE if connection_context .transport_ws_protocol else GQL_STOP
103+ ):
93104 return self .on_stop (connection_context , op_id )
94105
95106 else :
@@ -142,7 +153,12 @@ def build_message(self, id, op_type, payload):
142153
143154 def send_execution_result (self , connection_context , op_id , execution_result ):
144155 result = self .execution_result_to_dict (execution_result )
145- return self .send_message (connection_context , op_id , GQL_DATA , result )
156+ return self .send_message (
157+ connection_context ,
158+ op_id ,
159+ GQL_NEXT if connection_context .transport_ws_protocol else GQL_DATA ,
160+ result ,
161+ )
146162
147163 def execution_result_to_dict (self , execution_result ):
148164 result = OrderedDict ()
0 commit comments