11from inspect import isawaitable
2- from typing import Any , AsyncIterable , AsyncIterator , Dict , Optional , Union
2+ from typing import Any , AsyncIterable , AsyncIterator , Dict , Optional , Type , Union
33
44from ..error import GraphQLError , located_error
55from ..execution .collect_fields import collect_fields
@@ -29,6 +29,7 @@ async def subscribe(
2929 operation_name : Optional [str ] = None ,
3030 field_resolver : Optional [GraphQLFieldResolver ] = None ,
3131 subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
32+ execution_context_class : Optional [Type ["ExecutionContext" ]] = None ,
3233) -> Union [AsyncIterator [ExecutionResult ], ExecutionResult ]:
3334 """Create a GraphQL subscription.
3435
@@ -57,6 +58,7 @@ async def subscribe(
5758 variable_values ,
5859 operation_name ,
5960 subscribe_field_resolver ,
61+ execution_context_class ,
6062 )
6163 if isinstance (result_or_stream , ExecutionResult ):
6264 return result_or_stream
@@ -79,6 +81,7 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
7981 variable_values ,
8082 operation_name ,
8183 field_resolver ,
84+ execution_context_class = execution_context_class ,
8285 )
8386 return await result if isawaitable (result ) else result
8487
@@ -94,6 +97,7 @@ async def create_source_event_stream(
9497 variable_values : Optional [Dict [str , Any ]] = None ,
9598 operation_name : Optional [str ] = None ,
9699 subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
100+ execution_context_class : Optional [Type ["ExecutionContext" ]] = None ,
97101) -> Union [AsyncIterable [Any ], ExecutionResult ]:
98102 """Create source event stream
99103
@@ -122,9 +126,12 @@ async def create_source_event_stream(
122126 # mistake which should throw an early error.
123127 assert_valid_execution_arguments (schema , document , variable_values )
124128
129+ if not execution_context_class :
130+ execution_context_class = ExecutionContext
131+
125132 # If a valid context cannot be created due to incorrect arguments,
126133 # a "Response" with only errors is returned.
127- context = ExecutionContext .build (
134+ context = execution_context_class .build (
128135 schema ,
129136 document ,
130137 root_value ,
0 commit comments