33
44from pytest import mark , raises
55
6- from graphql .execution import MapAsyncIterator , create_source_event_stream , subscribe
6+ from graphql .execution import (
7+ ExecutionResult ,
8+ MapAsyncIterator ,
9+ create_source_event_stream ,
10+ subscribe ,
11+ )
712from graphql .language import parse
813from graphql .pyutils import SimplePubSub
914from graphql .type import (
@@ -132,6 +137,22 @@ def transform(new_email):
132137DummyQueryType = GraphQLObjectType ("Query" , {"dummy" : GraphQLField (GraphQLString )})
133138
134139
140+ async def subscribe_with_bad_fn (subscribe_fn : Callable ) -> ExecutionResult :
141+ schema = GraphQLSchema (
142+ query = DummyQueryType ,
143+ subscription = GraphQLObjectType (
144+ "Subscription" ,
145+ {"foo" : GraphQLField (GraphQLString , subscribe = subscribe_fn )},
146+ ),
147+ )
148+ document = parse ("subscription { foo }" )
149+ result = await subscribe (schema , document )
150+
151+ assert isinstance (result , ExecutionResult )
152+ assert await create_source_event_stream (schema , document ) == result
153+ return result
154+
155+
135156# Check all error cases when initializing the subscription.
136157def describe_subscription_initialization_phase ():
137158 @mark .asyncio
@@ -333,43 +354,15 @@ async def should_pass_through_unexpected_errors_thrown_in_subscribe():
333354 @mark .asyncio
334355 @mark .filterwarnings ("ignore:.* was never awaited:RuntimeWarning" )
335356 async def throws_an_error_if_subscribe_does_not_return_an_iterator ():
336- schema = GraphQLSchema (
337- query = DummyQueryType ,
338- subscription = GraphQLObjectType (
339- "Subscription" ,
340- {
341- "foo" : GraphQLField (
342- GraphQLString , subscribe = lambda _obj , _info : "test"
343- )
344- },
345- ),
346- )
347-
348- document = parse ("subscription { foo }" )
349-
350357 with raises (TypeError ) as exc_info :
351- await subscribe ( schema , document )
358+ await subscribe_with_bad_fn ( lambda _obj , _info : "test" )
352359
353360 assert str (exc_info .value ) == (
354361 "Subscription field must return AsyncIterable. Received: 'test'."
355362 )
356363
357364 @mark .asyncio
358365 async def resolves_to_an_error_for_subscription_resolver_errors ():
359- async def subscribe_with_fn (subscribe_fn : Callable ):
360- schema = GraphQLSchema (
361- query = DummyQueryType ,
362- subscription = GraphQLObjectType (
363- "Subscription" ,
364- {"foo" : GraphQLField (GraphQLString , subscribe = subscribe_fn )},
365- ),
366- )
367- document = parse ("subscription { foo }" )
368- result = await subscribe (schema , document )
369-
370- assert await create_source_event_stream (schema , document ) == result
371- return result
372-
373366 expected_result = (
374367 None ,
375368 [
@@ -385,25 +378,25 @@ async def subscribe_with_fn(subscribe_fn: Callable):
385378 def return_error (_obj , _info ):
386379 return TypeError ("test error" )
387380
388- assert await subscribe_with_fn (return_error ) == expected_result
381+ assert await subscribe_with_bad_fn (return_error ) == expected_result
389382
390383 # Throwing an error
391384 def throw_error (* _args ):
392385 raise TypeError ("test error" )
393386
394- assert await subscribe_with_fn (throw_error ) == expected_result
387+ assert await subscribe_with_bad_fn (throw_error ) == expected_result
395388
396389 # Resolving to an error
397390 async def resolve_error (* _args ):
398391 return TypeError ("test error" )
399392
400- assert await subscribe_with_fn (resolve_error ) == expected_result
393+ assert await subscribe_with_bad_fn (resolve_error ) == expected_result
401394
402395 # Rejecting with an error
403396 async def reject_error (* _args ):
404397 return TypeError ("test error" )
405398
406- assert await subscribe_with_fn (reject_error ) == expected_result
399+ assert await subscribe_with_bad_fn (reject_error ) == expected_result
407400
408401 @mark .asyncio
409402 async def resolves_to_an_error_if_variables_were_wrong_type ():
0 commit comments