1- from inspect import isawaitable
1+ from asyncio import Event , ensure_future , wait
2+ from concurrent .futures import FIRST_COMPLETED
3+ from inspect import isasyncgen , isawaitable
24from typing import AsyncIterable , Callable
35
46__all__ = ['MapAsyncIterator' ]
@@ -19,35 +21,62 @@ def __init__(self, iterable: AsyncIterable, callback: Callable,
1921 self .iterator = iterable .__aiter__ ()
2022 self .callback = callback
2123 self .reject_callback = reject_callback
22- self .stop = False
24+ self ._close_event = Event ()
25+
26+ @property
27+ def closed (self ) -> bool :
28+ return self ._close_event .is_set ()
29+
30+ @closed .setter
31+ def closed (self , value : bool ) -> None :
32+ if value :
33+ self ._close_event .set ()
34+ else :
35+ self ._close_event .clear ()
2336
2437 def __aiter__ (self ):
2538 return self
2639
2740 async def __anext__ (self ):
28- if self .stop :
41+ if self .closed :
42+ if not isasyncgen (self .iterator ):
43+ raise StopAsyncIteration
44+ result = await self .iterator .__anext__ ()
45+ return self .callback (result )
46+
47+ _close = ensure_future (self ._close_event .wait ())
48+ _next = ensure_future (self .iterator .__anext__ ())
49+ done , pending = await wait (
50+ [_close , _next ],
51+ return_when = FIRST_COMPLETED ,
52+ )
53+
54+ for task in pending :
55+ task .cancel ()
56+
57+ if _close .done ():
2958 raise StopAsyncIteration
30- try :
31- value = await self . iterator . __anext__ ()
32- except Exception as error :
33- if not self . reject_callback or isinstance ( error , (
34- StopAsyncIteration , GeneratorExit )):
35- raise
36- result = self . reject_callback ( error )
37- else :
38- result = self . callback ( value )
39- if isawaitable ( result ):
40- result = await result
41- return result
59+
60+ if _next . done ():
61+ error = _next . exception ()
62+ if error :
63+ if not self . reject_callback or isinstance ( error , (
64+ StopAsyncIteration , GeneratorExit )):
65+ raise error
66+ result = self . reject_callback ( error )
67+ else :
68+ result = self . callback ( _next . result ())
69+
70+ return ( await result ) if isawaitable ( result ) else result
4271
4372 async def athrow (self , type_ , value = None , traceback = None ):
44- if self .stop :
73+ if self .closed :
4574 return
4675 athrow = getattr (self .iterator , 'athrow' , None )
4776 if athrow :
4877 await athrow (type_ , value , traceback )
4978 else :
50- self .stop = True
79+ self .closed = True
5180 if value is None :
5281 if traceback is None :
5382 raise type_
@@ -57,13 +86,12 @@ async def athrow(self, type_, value=None, traceback=None):
5786 raise value
5887
5988 async def aclose (self ):
60- if self .stop :
89+ if self .closed :
6190 return
6291 aclose = getattr (self .iterator , 'aclose' , None )
6392 if aclose :
6493 try :
6594 await aclose ()
6695 except RuntimeError :
6796 pass
68- else :
69- self .stop = True
97+ self .closed = True
0 commit comments