@@ -77,15 +77,20 @@ def test_terminate(self, ss, cc):
7777 ss .process_message (cc , {"id" : "1" , "type" : constants .GQL_CONNECTION_TERMINATE })
7878 ss .on_connection_terminate .assert_called_with (cc , "1" )
7979
80- def test_start (self , ss , cc ):
80+ @pytest .mark .parametrize (
81+ "transport_ws_protocol,expected_type" ,
82+ ((False , constants .GQL_START ), (True , constants .GQL_SUBSCRIBE )),
83+ )
84+ def test_start (self , ss , cc , transport_ws_protocol , expected_type ):
8185 ss .get_graphql_params = mock .Mock ()
8286 ss .get_graphql_params .return_value = {"params" : True }
8387 cc .has_operation = mock .Mock ()
8488 cc .has_operation .return_value = False
89+ cc .transport_ws_protocol = transport_ws_protocol
8590 ss .unsubscribe = mock .Mock ()
8691 ss .on_start = mock .Mock ()
8792 ss .process_message (
88- cc , {"id" : "1" , "type" : constants . GQL_START , "payload" : {"a" : "b" }}
93+ cc , {"id" : "1" , "type" : expected_type , "payload" : {"a" : "b" }}
8994 )
9095 assert not ss .unsubscribe .called
9196 ss .on_start .assert_called_with (cc , "1" , {"params" : True })
@@ -117,9 +122,32 @@ def test_start_bad_graphql_params(self, ss, cc):
117122 assert isinstance (ss .send_error .call_args [0 ][2 ], Exception )
118123 assert not ss .on_start .called
119124
120- def test_stop (self , ss , cc ):
125+ @pytest .mark .parametrize (
126+ "transport_ws_protocol,stop_type,invalid_stop_type" ,
127+ (
128+ (False , constants .GQL_STOP , constants .GQL_COMPLETE ),
129+ (True , constants .GQL_COMPLETE , constants .GQL_STOP ),
130+ ),
131+ )
132+ def test_stop (
133+ self ,
134+ ss ,
135+ cc ,
136+ transport_ws_protocol ,
137+ stop_type ,
138+ invalid_stop_type ,
139+ ):
121140 ss .on_stop = mock .Mock ()
122- ss .process_message (cc , {"id" : "1" , "type" : constants .GQL_STOP })
141+ ss .send_error = mock .Mock ()
142+ cc .transport_ws_protocol = transport_ws_protocol
143+
144+ ss .process_message (cc , {"id" : "1" , "type" : invalid_stop_type })
145+ assert ss .send_error .called
146+ assert ss .send_error .call_args [0 ][:2 ] == (cc , "1" )
147+ assert isinstance (ss .send_error .call_args [0 ][2 ], Exception )
148+ assert not ss .on_stop .called
149+
150+ ss .process_message (cc , {"id" : "1" , "type" : stop_type })
123151 ss .on_stop .assert_called_with (cc , "1" )
124152
125153 def test_invalid (self , ss , cc ):
@@ -165,13 +193,18 @@ def test_build_message_partial(ss):
165193 ss .build_message (id = None , op_type = None , payload = None )
166194
167195
168- def test_send_execution_result (ss , cc ):
196+ @pytest .mark .parametrize (
197+ "transport_ws_protocol,expected_type" ,
198+ ((False , constants .GQL_DATA ), (True , constants .GQL_NEXT )),
199+ )
200+ def test_send_execution_result (ss , cc , transport_ws_protocol , expected_type ):
201+ cc .transport_ws_protocol = transport_ws_protocol
169202 ss .execution_result_to_dict = mock .Mock ()
170203 ss .execution_result_to_dict .return_value = {"res" : "ult" }
171204 ss .send_message = mock .Mock ()
172205 ss .send_message .return_value = "returned"
173206 assert "returned" == ss .send_execution_result (cc , "1" , "result" )
174- ss .send_message .assert_called_with (cc , "1" , constants . GQL_DATA , {"res" : "ult" })
207+ ss .send_message .assert_called_with (cc , "1" , expected_type , {"res" : "ult" })
175208
176209
177210def test_execution_result_to_dict (ss ):
0 commit comments