44
55from asyncio import (
66 CancelledError ,
7+ TimeoutError , # only needed for Python < 3.11 # noqa: A004
78 ensure_future ,
89)
910from contextlib import suppress
2930 cast ,
3031)
3132
32- try :
33- from typing import TypeAlias
34- except ImportError : # Python < 3.10
35- from typing_extensions import TypeAlias
36- try : # only needed for Python < 3.11
37- from asyncio .exceptions import TimeoutError # noqa: A004
38- except ImportError : # Python < 3.7
39- from concurrent .futures import TimeoutError # noqa: A004
40-
4133from ..error import GraphQLError , located_error
4234from ..language import (
4335 DocumentNode ,
119111if TYPE_CHECKING :
120112 from graphql .pyutils .undefined import UndefinedType
121113
114+ try :
115+ from typing import TypeAlias , TypeGuard
116+ except ImportError : # Python < 3.10
117+ from typing_extensions import TypeAlias , TypeGuard
118+
122119try : # pragma: no cover
123120 anext # noqa: B018 # pyright: ignore
124121except NameError : # pragma: no cover (Python < 3.10)
@@ -216,7 +213,9 @@ class ExecutionContext(IncrementalPublisherContext):
216213 cancellable_streams : set [CancellableStreamRecord ] | None
217214 middleware_manager : MiddlewareManager | None
218215
219- is_awaitable : Callable [[Any ], bool ] = staticmethod (default_is_awaitable )
216+ is_awaitable : Callable [[Any ], TypeGuard [Awaitable ]] = staticmethod (
217+ default_is_awaitable # type: ignore
218+ )
220219
221220 def __init__ (
222221 self ,
@@ -230,7 +229,7 @@ def __init__(
230229 type_resolver : GraphQLTypeResolver ,
231230 subscribe_field_resolver : GraphQLFieldResolver ,
232231 middleware_manager : MiddlewareManager | None ,
233- is_awaitable : Callable [[Any ], bool ] | None ,
232+ is_awaitable : Callable [[Any ], TypeGuard [ Awaitable ]] | None = None ,
234233 ) -> None :
235234 self .schema = schema
236235 self .fragments = fragments
@@ -242,8 +241,7 @@ def __init__(
242241 self .type_resolver = type_resolver
243242 self .subscribe_field_resolver = subscribe_field_resolver
244243 self .middleware_manager = middleware_manager
245- if is_awaitable :
246- self .is_awaitable = is_awaitable
244+ self .is_awaitable = is_awaitable or default_is_awaitable
247245 self .errors = None
248246 self .cancellable_streams = None
249247 self ._canceled_iterators : set [AsyncIterator ] = set ()
@@ -264,7 +262,7 @@ def build(
264262 type_resolver : GraphQLTypeResolver | None = None ,
265263 subscribe_field_resolver : GraphQLFieldResolver | None = None ,
266264 middleware : Middleware | None = None ,
267- is_awaitable : Callable [[Any ], bool ] | None = None ,
265+ is_awaitable : Callable [[Any ], TypeGuard [ Awaitable ] ] | None = None ,
268266 ** custom_args : Any ,
269267 ) -> list [GraphQLError ] | ExecutionContext :
270268 """Build an execution context
@@ -422,7 +420,7 @@ async def await_result() -> (
422420 ExecutionResult | ExperimentalIncrementalExecutionResults
423421 ):
424422 try :
425- resolved = await graphql_wrapped_result # type: ignore
423+ resolved = await graphql_wrapped_result
426424 except GraphQLError as error :
427425 return ExecutionResult (None , with_error (self .errors , error ))
428426 return self .build_data_response (
@@ -496,7 +494,7 @@ def reducer(
496494 if is_awaitable (result ):
497495
498496 async def set_result () -> GraphQLWrappedResult [dict [str , Any ]]:
499- resolved = await result # type: ignore
497+ resolved = await result
500498 graphql_wrapped_result .result [response_name ] = resolved .result
501499 graphql_wrapped_result .add_increments (resolved .increments )
502500 return graphql_wrapped_result
@@ -553,11 +551,12 @@ async def resolve(
553551 add_increments (resolved .increments )
554552 return resolved .result
555553
556- results [response_name ] = resolve (result ) # type: ignore
554+ results [response_name ] = resolve (result )
557555 append_awaitable (response_name )
558556 else :
559- results [response_name ] = result .result # type: ignore
560- add_increments (result .increments ) # type: ignore
557+ result = cast ("GraphQLWrappedResult[dict[str, Any]]" , result )
558+ results [response_name ] = result .result
559+ add_increments (result .increments )
561560
562561 # If there are no coroutines, we can just return the object.
563562 if not awaitable_fields :
@@ -651,7 +650,7 @@ def execute_field(
651650 # noinspection PyShadowingNames
652651 async def await_completed () -> Any :
653652 try :
654- return await completed # type: ignore
653+ return await completed
655654 except Exception as raw_error :
656655 # Before Python 3.8 CancelledError inherits Exception and
657656 # so gets caught here.
@@ -864,7 +863,7 @@ async def complete_awaitable_value(
864863 defer_map ,
865864 )
866865 if self .is_awaitable (completed ):
867- completed = await completed # type: ignore
866+ completed = await completed
868867 except Exception as raw_error :
869868 # Before Python 3.8 CancelledError inherits Exception and
870869 # so gets caught here.
@@ -1276,8 +1275,9 @@ async def complete_awaitable_list_item_value(
12761275 defer_map ,
12771276 )
12781277 if self .is_awaitable (completed ):
1279- completed = await completed # type: ignore
1280- parent .add_increments (completed .increments ) # type: ignore
1278+ completed = await completed
1279+ completed = cast ("GraphQLWrappedResult[list[Any]]" , completed )
1280+ parent .add_increments (completed .increments )
12811281 except Exception as raw_error :
12821282 self .handle_field_error (
12831283 raw_error ,
@@ -1288,7 +1288,7 @@ async def complete_awaitable_list_item_value(
12881288 )
12891289 return None
12901290 else :
1291- return completed .result # type: ignore
1291+ return completed .result
12921292
12931293 @staticmethod
12941294 def complete_leaf_value (return_type : GraphQLLeafType , result : Any ) -> Any :
@@ -1326,7 +1326,6 @@ def complete_abstract_value(
13261326 runtime_type = resolve_type_fn (result , info , return_type )
13271327
13281328 if self .is_awaitable (runtime_type ):
1329- runtime_type = cast ("Awaitable" , runtime_type )
13301329
13311330 async def await_complete_object_value () -> Any :
13321331 value = self .complete_object_value (
@@ -1345,7 +1344,7 @@ async def await_complete_object_value() -> Any:
13451344 defer_map ,
13461345 )
13471346 if self .is_awaitable (value ):
1348- return await value # type: ignore
1347+ return await value
13491348 return value # pragma: no cover
13501349
13511350 return await_complete_object_value ()
@@ -1447,7 +1446,7 @@ def complete_object_value(
14471446 async def execute_subfields_async () -> GraphQLWrappedResult [
14481447 dict [str , Any ]
14491448 ]:
1450- if not await is_type_of : # type: ignore
1449+ if not await is_type_of :
14511450 raise invalid_return_type_error (
14521451 return_type , result , field_group
14531452 )
@@ -1460,8 +1459,10 @@ async def execute_subfields_async() -> GraphQLWrappedResult[
14601459 defer_map ,
14611460 )
14621461 if self .is_awaitable (graphql_wrapped_result ): # pragma: no cover
1463- return await graphql_wrapped_result # type: ignore
1464- return graphql_wrapped_result # type: ignore
1462+ return await graphql_wrapped_result
1463+ return cast (
1464+ "GraphQLWrappedResult[dict[str, Any]]" , graphql_wrapped_result
1465+ )
14651466
14661467 return execute_subfields_async ()
14671468
@@ -1644,8 +1645,8 @@ async def executor(
16441645 defer_map ,
16451646 )
16461647 if self .is_awaitable (result ):
1647- return await result # type: ignore
1648- return result # type: ignore
1648+ return await result
1649+ return cast ( "DeferredGroupedFieldSetResult" , result )
16491650
16501651 deferred_grouped_field_set_record = DeferredGroupedFieldSetRecord (
16511652 deferred_fragment_records ,
@@ -1702,7 +1703,7 @@ def execute_deferred_grouped_field_set(
17021703
17031704 async def await_result () -> DeferredGroupedFieldSetResult :
17041705 try :
1705- awaited_result = await result # type: ignore
1706+ awaited_result = await result
17061707 except GraphQLError as error :
17071708 return NonReconcilableDeferredGroupedFieldSetResult (
17081709 deferred_fragment_records ,
@@ -1792,8 +1793,8 @@ async def await_result() -> StreamItemsResult:
17921793
17931794 result = first_stream_items .result
17941795 if is_awaitable (result ):
1795- return await result # type: ignore
1796- return result # type: ignore
1796+ return await result
1797+ return cast ( "StreamItemsResult" , result )
17971798
17981799 return StreamItemsRecord (stream_record , await_result ())
17991800
@@ -1864,8 +1865,8 @@ async def get_next_async_stream_items_result(
18641865 result = self .prepend_next_stream_items (result , next_stream_items_record )
18651866
18661867 if self .is_awaitable (result ):
1867- return await result # type: ignore
1868- return result # type: ignore
1868+ return await result
1869+ return cast ( "StreamItemsResult" , result )
18691870
18701871 def complete_stream_items (
18711872 self ,
@@ -1932,7 +1933,7 @@ async def await_item() -> StreamItemsResult:
19321933 async def await_item () -> StreamItemsResult :
19331934 try :
19341935 try :
1935- awaited_item = await result # type: ignore
1936+ awaited_item = await result
19361937 except Exception as raw_error :
19371938 self .handle_field_error (
19381939 raw_error ,
@@ -1967,13 +1968,13 @@ def prepend_next_stream_items(
19671968 if self .is_awaitable (result ):
19681969
19691970 async def await_result () -> StreamItemsResult :
1970- resolved = await result # type: ignore
1971+ resolved = await result
19711972 return prepend_next_resolved_stream_items (resolved , next_stream_items )
19721973
19731974 return await_result ()
19741975
19751976 return prepend_next_resolved_stream_items (
1976- result , # type: ignore
1977+ cast ( "StreamItemsResult" , result ),
19771978 next_stream_items ,
19781979 )
19791980
@@ -1986,7 +1987,7 @@ def with_new_deferred_grouped_field_sets(
19861987 if self .is_awaitable (result ):
19871988
19881989 async def await_result () -> GraphQLWrappedResult [dict [str , Any ]]:
1989- resolved = await result # type: ignore
1990+ resolved = await result
19901991 resolved .add_increments (new_deferred_grouped_field_set_records )
19911992 return resolved
19921993
@@ -2091,7 +2092,7 @@ def execute(
20912092 subscribe_field_resolver : GraphQLFieldResolver | None = None ,
20922093 middleware : Middleware | None = None ,
20932094 execution_context_class : type [ExecutionContext ] | None = None ,
2094- is_awaitable : Callable [[Any ], bool ] | None = None ,
2095+ is_awaitable : Callable [[Any ], TypeGuard [ Awaitable ] ] | None = None ,
20952096 ** custom_context_args : Any ,
20962097) -> AwaitableOrValue [ExecutionResult ]:
20972098 """Execute a GraphQL operation.
@@ -2153,7 +2154,7 @@ def experimental_execute_incrementally(
21532154 subscribe_field_resolver : GraphQLFieldResolver | None = None ,
21542155 middleware : Middleware | None = None ,
21552156 execution_context_class : type [ExecutionContext ] | None = None ,
2156- is_awaitable : Callable [[Any ], bool ] | None = None ,
2157+ is_awaitable : Callable [[Any ], TypeGuard [ Awaitable ] ] | None = None ,
21572158 ** custom_context_args : Any ,
21582159) -> AwaitableOrValue [ExecutionResult | ExperimentalIncrementalExecutionResults ]:
21592160 """Execute GraphQL operation incrementally (internal implementation).
@@ -2193,7 +2194,7 @@ def experimental_execute_incrementally(
21932194 return context .execute_operation ()
21942195
21952196
2196- def assume_not_awaitable (_value : Any ) -> bool :
2197+ def assume_not_awaitable (_value : Any ) -> TypeGuard [ Awaitable ] :
21972198 """Replacement for is_awaitable if everything is assumed to be synchronous."""
21982199 return False
21992200
@@ -2221,7 +2222,7 @@ def execute_sync(
22212222 Set check_sync to True to still run checks that no awaitable values are returned.
22222223 """
22232224 is_awaitable = (
2224- check_sync
2225+ cast ( "Callable[[Any], TypeGuard[Awaitable]]" , check_sync )
22252226 if callable (check_sync )
22262227 else (None if check_sync else assume_not_awaitable )
22272228 )
@@ -2434,7 +2435,7 @@ def default_type_resolver(
24342435 return type_ .name
24352436
24362437 if awaitable_is_type_of_results :
2437- # noinspection PyShadowingNames
2438+
24382439 async def get_type () -> str | None :
24392440 is_type_of_results = await gather_with_cancel (* awaitable_is_type_of_results )
24402441 for is_type_of_result , type_ in zip (is_type_of_results , awaitable_types ):
@@ -2533,9 +2534,9 @@ def subscribe(
25332534 result_or_stream = create_source_event_stream_impl (context )
25342535
25352536 if context .is_awaitable (result_or_stream ):
2536- # noinspection PyShadowingNames
2537+
25372538 async def await_result () -> Any :
2538- awaited_result_or_stream = await result_or_stream # type: ignore
2539+ awaited_result_or_stream = await result_or_stream
25392540 if isinstance (awaited_result_or_stream , ExecutionResult ):
25402541 return awaited_result_or_stream
25412542 return context .map_source_to_response (awaited_result_or_stream )
@@ -2616,12 +2617,10 @@ def create_source_event_stream_impl(
26162617 return ExecutionResult (None , errors = [error ])
26172618
26182619 if context .is_awaitable (event_stream ):
2619- awaitable_event_stream = cast ("Awaitable" , event_stream )
26202620
2621- # noinspection PyShadowingNames
26222621 async def await_event_stream () -> AsyncIterable [Any ] | ExecutionResult :
26232622 try :
2624- return await awaitable_event_stream
2623+ return await event_stream
26252624 except GraphQLError as error :
26262625 return ExecutionResult (None , errors = [error ])
26272626
0 commit comments