11from inspect import isawaitable
2- from typing import Any , AsyncIterable , AsyncIterator , Dict , Optional , Type , Union
2+ from typing import (
3+ Any ,
4+ AsyncIterable ,
5+ AsyncIterator ,
6+ Awaitable ,
7+ Dict ,
8+ Optional ,
9+ Type ,
10+ Union ,
11+ cast ,
12+ )
313
414from ..error import GraphQLError , located_error
515from ..execution .collect_fields import collect_fields
1121)
1222from ..execution .values import get_argument_values
1323from ..language import DocumentNode
14- from ..pyutils import Path , inspect
24+ from ..pyutils import AwaitableOrValue , Path , inspect
1525from ..type import GraphQLFieldResolver , GraphQLSchema
1626from .map_async_iterator import MapAsyncIterator
1727
1828
1929__all__ = ["subscribe" , "create_source_event_stream" ]
2030
2131
22- async def subscribe (
32+ def subscribe (
2333 schema : GraphQLSchema ,
2434 document : DocumentNode ,
2535 root_value : Any = None ,
@@ -29,7 +39,7 @@ async def subscribe(
2939 field_resolver : Optional [GraphQLFieldResolver ] = None ,
3040 subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
3141 execution_context_class : Optional [Type [ExecutionContext ]] = None ,
32- ) -> Union [AsyncIterator [ExecutionResult ], ExecutionResult ]:
42+ ) -> AwaitableOrValue [ Union [AsyncIterator [ExecutionResult ], ExecutionResult ] ]:
3343 """Create a GraphQL subscription.
3444
3545 Implements the "Subscribe" algorithm described in the GraphQL spec.
@@ -49,7 +59,7 @@ async def subscribe(
4959 If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
5060 a stream of ExecutionResults representing the response stream.
5161 """
52- result_or_stream = await create_source_event_stream (
62+ result_or_stream = create_source_event_stream (
5363 schema ,
5464 document ,
5565 root_value ,
@@ -59,8 +69,6 @@ async def subscribe(
5969 subscribe_field_resolver ,
6070 execution_context_class ,
6171 )
62- if isinstance (result_or_stream , ExecutionResult ):
63- return result_or_stream
6472
6573 async def map_source_to_response (payload : Any ) -> ExecutionResult :
6674 """Map source to response.
@@ -84,11 +92,28 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
8492 )
8593 return await result if isawaitable (result ) else result
8694
95+ if (execution_context_class or ExecutionContext ).is_awaitable (result_or_stream ):
96+ awaitable_result_or_stream = cast (Awaitable , result_or_stream )
97+
98+ # noinspection PyShadowingNames
99+ async def await_result () -> Any :
100+ result_or_stream = await awaitable_result_or_stream
101+ if isinstance (result_or_stream , ExecutionResult ):
102+ return result_or_stream
103+ return MapAsyncIterator (result_or_stream , map_source_to_response )
104+
105+ return await_result ()
106+
107+ if isinstance (result_or_stream , ExecutionResult ):
108+ return result_or_stream
109+
87110 # Map every source value to a ExecutionResult value as described above.
88- return MapAsyncIterator (result_or_stream , map_source_to_response )
111+ return MapAsyncIterator (
112+ cast (AsyncIterable [Any ], result_or_stream ), map_source_to_response
113+ )
89114
90115
91- async def create_source_event_stream (
116+ def create_source_event_stream (
92117 schema : GraphQLSchema ,
93118 document : DocumentNode ,
94119 root_value : Any = None ,
@@ -97,7 +122,7 @@ async def create_source_event_stream(
97122 operation_name : Optional [str ] = None ,
98123 subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
99124 execution_context_class : Optional [Type [ExecutionContext ]] = None ,
100- ) -> Union [AsyncIterable [Any ], ExecutionResult ]:
125+ ) -> AwaitableOrValue [ Union [AsyncIterable [Any ], ExecutionResult ] ]:
101126 """Create source event stream
102127
103128 Implements the "CreateSourceEventStream" algorithm described in the GraphQL
@@ -145,12 +170,28 @@ async def create_source_event_stream(
145170 return ExecutionResult (data = None , errors = context )
146171
147172 try :
148- return await execute_subscription (context )
173+ event_stream = execute_subscription (context )
149174 except GraphQLError as error :
150175 return ExecutionResult (data = None , errors = [error ])
151176
177+ if context .is_awaitable (event_stream ):
178+ awaitable_event_stream = cast (Awaitable , event_stream )
179+
180+ # noinspection PyShadowingNames
181+ async def await_event_stream () -> Union [AsyncIterable [Any ], ExecutionResult ]:
182+ try :
183+ return await awaitable_event_stream
184+ except GraphQLError as error :
185+ return ExecutionResult (data = None , errors = [error ])
152186
153- async def execute_subscription (context : ExecutionContext ) -> AsyncIterable [Any ]:
187+ return await_event_stream ()
188+
189+ return event_stream
190+
191+
192+ def execute_subscription (
193+ context : ExecutionContext ,
194+ ) -> AwaitableOrValue [AsyncIterable [Any ]]:
154195 schema = context .schema
155196
156197 root_type = schema .subscription_type
@@ -191,19 +232,33 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
191232 # AsyncIterable yielding raw payloads.
192233 resolve_fn = field_def .subscribe or context .subscribe_field_resolver
193234
194- event_stream = resolve_fn (context .root_value , info , ** args )
195- if context .is_awaitable (event_stream ):
196- event_stream = await event_stream
197- if isinstance (event_stream , Exception ):
198- raise event_stream
235+ result = resolve_fn (context .root_value , info , ** args )
236+ if context .is_awaitable (result ):
199237
200- # Assert field returned an event stream, otherwise yield an error.
201- if not isinstance (event_stream , AsyncIterable ):
202- raise GraphQLError (
203- "Subscription field must return AsyncIterable."
204- f" Received: { inspect (event_stream )} ."
205- )
238+ # noinspection PyShadowingNames
239+ async def await_result () -> AsyncIterable [Any ]:
240+ try :
241+ return assert_event_stream (await result )
242+ except Exception as error :
243+ raise located_error (error , field_nodes , path .as_list ())
244+
245+ return await_result ()
246+
247+ return assert_event_stream (result )
206248
207- return event_stream
208249 except Exception as error :
209250 raise located_error (error , field_nodes , path .as_list ())
251+
252+
253+ def assert_event_stream (result : Any ) -> AsyncIterable :
254+ if isinstance (result , Exception ):
255+ raise result
256+
257+ # Assert field returned an event stream, otherwise yield an error.
258+ if not isinstance (result , AsyncIterable ):
259+ raise GraphQLError (
260+ "Subscription field must return AsyncIterable."
261+ f" Received: { inspect (result )} ."
262+ )
263+
264+ return result
0 commit comments