@@ -331,6 +331,23 @@ def build_response(
331331 )
332332 return ExecutionResult (data , errors )
333333
334+ def build_per_event_execution_context (self , payload : Any ) -> ExecutionContext :
335+ """Create a copy of the execution context for usage with subscribe events."""
336+ return self .__class__ (
337+ self .schema ,
338+ self .fragments ,
339+ payload ,
340+ self .context_value ,
341+ self .operation ,
342+ self .variable_values ,
343+ self .field_resolver ,
344+ self .type_resolver ,
345+ self .subscribe_field_resolver ,
346+ [],
347+ self .middleware_manager ,
348+ self .is_awaitable ,
349+ )
350+
334351 def execute_operation (self ) -> AwaitableOrValue [Any ]:
335352 """Execute an operation.
336353
@@ -1003,7 +1020,7 @@ def execute(
10031020
10041021 # If a valid execution context cannot be created due to incorrect arguments,
10051022 # a "Response" with only errors is returned.
1006- exe_context = execution_context_class .build (
1023+ context = execution_context_class .build (
10071024 schema ,
10081025 document ,
10091026 root_value ,
@@ -1018,9 +1035,14 @@ def execute(
10181035 )
10191036
10201037 # Return early errors if execution context failed.
1021- if isinstance (exe_context , list ):
1022- return ExecutionResult (data = None , errors = exe_context )
1038+ if isinstance (context , list ):
1039+ return ExecutionResult (data = None , errors = context )
1040+
1041+ return execute_impl (context )
1042+
10231043
1044+ def execute_impl (context : ExecutionContext ) -> AwaitableOrValue [ExecutionResult ]:
1045+ """Execute GraphQL operation (internal implementation)."""
10241046 # Return a possible coroutine object that will eventually yield the data described
10251047 # by the "Response" section of the GraphQL specification.
10261048 #
@@ -1032,12 +1054,12 @@ def execute(
10321054 # Errors from sub-fields of a NonNull type may propagate to the top level,
10331055 # at which point we still log the error and null the parent field, which
10341056 # in this case is the entire response.
1035- errors = exe_context .errors
1036- build_response = exe_context .build_response
1057+ errors = context .errors
1058+ build_response = context .build_response
10371059 try :
1038- result = exe_context .execute_operation ()
1060+ result = context .execute_operation ()
10391061
1040- if exe_context .is_awaitable (result ):
1062+ if context .is_awaitable (result ):
10411063 # noinspection PyShadowingNames
10421064 async def await_result () -> Any :
10431065 try :
@@ -1215,6 +1237,7 @@ def subscribe(
12151237 variable_values : Optional [Dict [str , Any ]] = None ,
12161238 operation_name : Optional [str ] = None ,
12171239 field_resolver : Optional [GraphQLFieldResolver ] = None ,
1240+ type_resolver : Optional [GraphQLTypeResolver ] = None ,
12181241 subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
12191242 execution_context_class : Optional [Type [ExecutionContext ]] = None ,
12201243) -> AwaitableOrValue [Union [AsyncIterator [ExecutionResult ], ExecutionResult ]]:
@@ -1237,17 +1260,31 @@ def subscribe(
12371260 If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
12381261 a stream of ExecutionResults representing the response stream.
12391262 """
1240- result_or_stream = create_source_event_stream (
1263+ if execution_context_class is None :
1264+ execution_context_class = ExecutionContext
1265+
1266+ # If a valid context cannot be created due to incorrect arguments,
1267+ # a "Response" with only errors is returned.
1268+ context = execution_context_class .build (
12411269 schema ,
12421270 document ,
12431271 root_value ,
12441272 context_value ,
12451273 variable_values ,
12461274 operation_name ,
1275+ field_resolver ,
1276+ type_resolver ,
12471277 subscribe_field_resolver ,
1248- execution_context_class ,
12491278 )
12501279
1280+ # Return early errors if execution context failed.
1281+ if isinstance (context , list ):
1282+ return ExecutionResult (data = None , errors = context )
1283+
1284+ result_or_stream = create_source_event_stream_impl (context )
1285+
1286+ build_context = context .build_per_event_execution_context
1287+
12511288 async def map_source_to_response (payload : Any ) -> ExecutionResult :
12521289 """Map source to response.
12531290
@@ -1258,19 +1295,10 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
12581295 "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
12591296 "ExecuteQuery" algorithm, for which :func:`~graphql.execute` is also used.
12601297 """
1261- result = execute (
1262- schema ,
1263- document ,
1264- payload ,
1265- context_value ,
1266- variable_values ,
1267- operation_name ,
1268- field_resolver ,
1269- execution_context_class = execution_context_class ,
1270- )
1298+ result = execute_impl (build_context (payload ))
12711299 return await result if isawaitable (result ) else result
12721300
1273- if ( execution_context_class or ExecutionContext ) .is_awaitable (result_or_stream ):
1301+ if execution_context_class .is_awaitable (result_or_stream ):
12741302 awaitable_result_or_stream = cast (Awaitable , result_or_stream )
12751303
12761304 # noinspection PyShadowingNames
@@ -1298,6 +1326,8 @@ def create_source_event_stream(
12981326 context_value : Any = None ,
12991327 variable_values : Optional [Dict [str , Any ]] = None ,
13001328 operation_name : Optional [str ] = None ,
1329+ field_resolver : Optional [GraphQLFieldResolver ] = None ,
1330+ type_resolver : Optional [GraphQLTypeResolver ] = None ,
13011331 subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
13021332 execution_context_class : Optional [Type [ExecutionContext ]] = None ,
13031333) -> AwaitableOrValue [Union [AsyncIterable [Any ], ExecutionResult ]]:
@@ -1336,13 +1366,22 @@ def create_source_event_stream(
13361366 context_value ,
13371367 variable_values ,
13381368 operation_name ,
1339- subscribe_field_resolver = subscribe_field_resolver ,
1369+ field_resolver ,
1370+ type_resolver ,
1371+ subscribe_field_resolver ,
13401372 )
13411373
13421374 # Return early errors if execution context failed.
13431375 if isinstance (context , list ):
13441376 return ExecutionResult (data = None , errors = context )
13451377
1378+ return create_source_event_stream_impl (context )
1379+
1380+
1381+ def create_source_event_stream_impl (
1382+ context : ExecutionContext ,
1383+ ) -> AwaitableOrValue [Union [AsyncIterable [Any ], ExecutionResult ]]:
1384+ """Create source event stream (internal implementation)."""
13461385 try :
13471386 event_stream = execute_subscription (context )
13481387 except GraphQLError as error :
0 commit comments