22
33from asyncio import Event , as_completed , ensure_future , gather , shield , sleep , wait_for
44from collections .abc import Mapping
5+ from contextlib import suppress
56from inspect import isawaitable
67from typing import (
78 Any ,
1718 NamedTuple ,
1819 Optional ,
1920 Sequence ,
21+ Set ,
2022 Tuple ,
2123 Type ,
2224 Union ,
@@ -673,6 +675,7 @@ def __init__(
673675 self .middleware_manager = middleware_manager
674676 if is_awaitable :
675677 self .is_awaitable = is_awaitable
678+ self ._canceled_iterators : Set [AsyncIterator ] = set ()
676679 self ._subfields_cache : Dict [Tuple , FieldsAndPatches ] = {}
677680
678681 @classmethod
@@ -1006,6 +1009,7 @@ async def await_completed() -> Any:
10061009 except Exception as raw_error :
10071010 error = located_error (raw_error , field_nodes , path .as_list ())
10081011 handle_field_error (error , return_type , errors )
1012+ self .filter_subsequent_payloads (path )
10091013 return None
10101014
10111015 return await_completed ()
@@ -1014,6 +1018,7 @@ async def await_completed() -> Any:
10141018 except Exception as raw_error :
10151019 error = located_error (raw_error , field_nodes , path .as_list ())
10161020 handle_field_error (error , return_type , errors )
1021+ self .filter_subsequent_payloads (path )
10171022 return None
10181023
10191024 def build_resolve_info (
@@ -1305,6 +1310,7 @@ def complete_list_value(
13051310 and index >= stream .initial_count
13061311 ):
13071312 previous_async_payload_record = self .execute_stream_field (
1313+ path ,
13081314 item_path ,
13091315 item ,
13101316 field_nodes ,
@@ -1334,6 +1340,7 @@ async def await_completed(item: Any, item_path: Path) -> Any:
13341340 raw_error , field_nodes , item_path .as_list ()
13351341 )
13361342 handle_field_error (error , item_type , errors )
1343+ self .filter_subsequent_payloads (item_path )
13371344 return None
13381345
13391346 completed_item = await_completed (item , item_path )
@@ -1357,12 +1364,14 @@ async def await_completed(item: Any, item_path: Path) -> Any:
13571364 raw_error , field_nodes , item_path .as_list ()
13581365 )
13591366 handle_field_error (error , item_type , errors )
1367+ self .filter_subsequent_payloads (item_path )
13601368 return None
13611369
13621370 completed_item = await_completed (completed_item , item_path )
13631371 except Exception as raw_error :
13641372 error = located_error (raw_error , field_nodes , item_path .as_list ())
13651373 handle_field_error (error , item_type , errors )
1374+ self .filter_subsequent_payloads (item_path )
13661375 completed_item = None
13671376
13681377 if is_awaitable (completed_item ):
@@ -1694,14 +1703,17 @@ async def await_data(
16941703 def execute_stream_field (
16951704 self ,
16961705 path : Path ,
1706+ item_path : Path ,
16971707 item : AwaitableOrValue [Any ],
16981708 field_nodes : List [FieldNode ],
16991709 info : GraphQLResolveInfo ,
17001710 item_type : GraphQLOutputType ,
17011711 label : Optional [str ] = None ,
17021712 parent_context : Optional [AsyncPayloadRecord ] = None ,
17031713 ) -> AsyncPayloadRecord :
1704- async_payload_record = StreamRecord (label , path , None , parent_context , self )
1714+ async_payload_record = StreamRecord (
1715+ label , item_path , None , parent_context , self
1716+ )
17051717 completed_item : Any
17061718 completed_items : Any
17071719 try :
@@ -1713,7 +1725,7 @@ async def await_completed_item() -> Any:
17131725 item_type ,
17141726 field_nodes ,
17151727 info ,
1716- path ,
1728+ item_path ,
17171729 await item ,
17181730 async_payload_record ,
17191731 )
@@ -1727,7 +1739,12 @@ async def await_completed_item() -> Any:
17271739
17281740 else :
17291741 completed_item = self .complete_value (
1730- item_type , field_nodes , info , path , item , async_payload_record
1742+ item_type ,
1743+ field_nodes ,
1744+ info ,
1745+ item_path ,
1746+ item ,
1747+ async_payload_record ,
17311748 )
17321749
17331750 if self .is_awaitable (completed_item ):
@@ -1739,24 +1756,31 @@ async def await_completed_item() -> Any:
17391756 except Exception as raw_error :
17401757 # noinspection PyShadowingNames
17411758 error = located_error (
1742- raw_error , field_nodes , path .as_list ()
1759+ raw_error , field_nodes , item_path .as_list ()
17431760 )
17441761 handle_field_error (
17451762 error , item_type , async_payload_record .errors
17461763 )
1764+ self .filter_subsequent_payloads (
1765+ item_path , async_payload_record
1766+ )
17471767 return None
17481768
17491769 complete_item = await_completed_item ()
17501770
17511771 else :
17521772 complete_item = completed_item
17531773 except Exception as raw_error :
1754- error = located_error (raw_error , field_nodes , path .as_list ())
1774+ error = located_error (raw_error , field_nodes , item_path .as_list ())
17551775 handle_field_error (error , item_type , async_payload_record .errors )
1776+ self .filter_subsequent_payloads ( # pragma: no cover
1777+ item_path , async_payload_record
1778+ )
17561779 complete_item = None # pragma: no cover
17571780
17581781 except GraphQLError as error :
17591782 async_payload_record .errors .append (error )
1783+ self .filter_subsequent_payloads (item_path , async_payload_record )
17601784 async_payload_record .add_items (None )
17611785 return async_payload_record
17621786
@@ -1768,6 +1792,7 @@ async def await_completed_items() -> Optional[List[Any]]:
17681792 return [await complete_item ] # type: ignore
17691793 except GraphQLError as error :
17701794 async_payload_record .errors .append (error )
1795+ self .filter_subsequent_payloads (path , async_payload_record )
17711796 return None
17721797
17731798 completed_items = await_completed_items ()
@@ -1786,6 +1811,8 @@ async def execute_stream_iterator_item(
17861811 async_payload_record : StreamRecord ,
17871812 field_path : Path ,
17881813 ) -> Any :
1814+ if iterator in self ._canceled_iterators :
1815+ raise StopAsyncIteration
17891816 try :
17901817 item = await anext (iterator )
17911818 completed_item = self .complete_value (
@@ -1799,12 +1826,13 @@ async def execute_stream_iterator_item(
17991826 )
18001827
18011828 except StopAsyncIteration as raw_error :
1802- async_payload_record .set_ist_completed_iterator ()
1829+ async_payload_record .set_is_completed_iterator ()
18031830 raise StopAsyncIteration from raw_error
18041831
18051832 except Exception as raw_error :
18061833 error = located_error (raw_error , field_nodes , field_path .as_list ())
18071834 handle_field_error (error , item_type , async_payload_record .errors )
1835+ self .filter_subsequent_payloads (field_path , async_payload_record )
18081836
18091837 async def execute_stream_iterator (
18101838 self ,
@@ -1830,30 +1858,50 @@ async def execute_stream_iterator(
18301858 iterator , field_modes , info , item_type , async_payload_record , field_path
18311859 )
18321860
1833- # noinspection PyShadowingNames
1834- async def items (
1835- data : Awaitable [Any ], async_payload_record : StreamRecord
1836- ) -> AwaitableOrValue [Optional [List [Any ]]]:
1837- try :
1838- return [await data ]
1839- except GraphQLError as error :
1840- async_payload_record .errors .append (error )
1841- return None
1842-
18431861 try :
1844- async_payload_record .add_items (
1845- await items (awaitable_data , async_payload_record )
1846- )
1862+ data = await awaitable_data
18471863 except StopAsyncIteration :
18481864 if async_payload_record .errors :
1849- async_payload_record .add_items ([ None ] ) # pragma: no cover
1865+ async_payload_record .add_items (None ) # pragma: no cover
18501866 else :
18511867 del self .subsequent_payloads [async_payload_record ]
18521868 break
1869+ except GraphQLError as error :
1870+ # entire stream has errored and bubbled upwards
1871+ self .filter_subsequent_payloads (path , async_payload_record )
1872+ if iterator : # pragma: no cover else
1873+ with suppress (Exception ):
1874+ await iterator .aclose () # type: ignore
1875+ # running generators cannot be closed since Python 3.8,
1876+ # so we need to remember that this iterator is already canceled
1877+ self ._canceled_iterators .add (iterator )
1878+ async_payload_record .add_items (None )
1879+ async_payload_record .errors .append (error )
1880+ break
1881+
1882+ async_payload_record .add_items ([data ])
18531883
18541884 previous_async_payload_record = async_payload_record
18551885 index += 1
18561886
1887+ def filter_subsequent_payloads (
1888+ self ,
1889+ null_path : Optional [Path ] = None ,
1890+ current_async_record : Optional [AsyncPayloadRecord ] = None ,
1891+ ) -> None :
1892+ null_path_list = null_path .as_list () if null_path else []
1893+ for async_record in list (self .subsequent_payloads ):
1894+ if async_record is current_async_record :
1895+ # don't remove payload from where error originates
1896+ continue
1897+ if async_record .path [: len (null_path_list )] != null_path_list :
1898+ # async_record points to a path unaffected by this payload
1899+ continue
1900+ # async_record path points to nulled error field
1901+ if isinstance (async_record , StreamRecord ) and async_record .iterator :
1902+ self ._canceled_iterators .add (async_record .iterator )
1903+ del self .subsequent_payloads [async_record ]
1904+
18571905 def get_completed_incremental_results (self ) -> List [IncrementalResult ]:
18581906 incremental_results : List [IncrementalResult ] = []
18591907 append_result = incremental_results .append
@@ -2661,12 +2709,16 @@ async def wait(self) -> Optional[Dict[str, Any]]:
26612709 if self .parent_context :
26622710 await self .parent_context .completed .wait ()
26632711 _data = self ._data
2664- data = (
2665- await _data if self ._context .is_awaitable (_data ) else _data # type: ignore
2666- )
2667- self .data = data
2668- await sleep (ASYNC_DELAY ) # always defer completion a little bit
2669- self .completed .set ()
2712+ try :
2713+ data = (
2714+ await _data # type: ignore
2715+ if self ._context .is_awaitable (_data )
2716+ else _data
2717+ )
2718+ finally :
2719+ await sleep (ASYNC_DELAY ) # always defer completion a little bit
2720+ self .data = data
2721+ self .completed .set ()
26702722 return data
26712723
26722724 def add_data (self , data : AwaitableOrValue [Optional [Dict [str , Any ]]]) -> None :
@@ -2728,21 +2780,23 @@ async def wait(self) -> Optional[List[str]]:
27282780 if self .parent_context :
27292781 await self .parent_context .completed .wait ()
27302782 _items = self ._items
2731- items = (
2732- await _items # type: ignore
2733- if self ._context .is_awaitable (_items )
2734- else _items
2735- )
2736- self .items = items
2737- await sleep (ASYNC_DELAY ) # always defer completion a little bit
2738- self .completed .set ()
2783+ try :
2784+ items = (
2785+ await _items # type: ignore
2786+ if self ._context .is_awaitable (_items )
2787+ else _items
2788+ )
2789+ finally :
2790+ await sleep (ASYNC_DELAY ) # always defer completion a little bit
2791+ self .items = items
2792+ self .completed .set ()
27392793 return items
27402794
27412795 def add_items (self , items : AwaitableOrValue [Optional [List [Any ]]]) -> None :
27422796 self ._items = items
27432797 self ._items_added .set ()
27442798
2745- def set_ist_completed_iterator (self ) -> None :
2799+ def set_is_completed_iterator (self ) -> None :
27462800 self .is_completed_iterator = True
27472801 self ._items_added .set ()
27482802
0 commit comments