@@ -13,7 +13,40 @@ async def anext(iterable):
1313
1414def describe_map_async_iterator ():
1515 @mark .asyncio
16- async def maps_over_async_values ():
16+ async def maps_over_async_generator ():
17+ async def source ():
18+ yield 1
19+ yield 2
20+ yield 3
21+
22+ doubles = MapAsyncIterator (source (), lambda x : x + x )
23+
24+ assert await anext (doubles ) == 2
25+ assert await anext (doubles ) == 4
26+ assert await anext (doubles ) == 6
27+ with raises (StopAsyncIteration ):
28+ assert await anext (doubles )
29+
30+ @mark .asyncio
31+ async def maps_over_async_iterator ():
32+ items = [1 , 2 , 3 ]
33+
34+ class Iterator :
35+ def __aiter__ (self ):
36+ return self
37+
38+ async def __anext__ (self ):
39+ try :
40+ return items .pop (0 )
41+ except IndexError :
42+ raise StopAsyncIteration
43+
44+ doubles = MapAsyncIterator (Iterator (), lambda x : x + x )
45+
46+ assert [value async for value in doubles ] == [2 , 4 , 6 ]
47+
48+ @mark .asyncio
49+ async def compatible_with_async_for ():
1750 async def source ():
1851 yield 1
1952 yield 2
@@ -38,11 +71,11 @@ async def double(x):
3871 assert [value async for value in doubles ] == [2 , 4 , 6 ]
3972
4073 @mark .asyncio
41- async def allows_returning_early_from_async_values ():
74+ async def allows_returning_early_from_mapped_async_generator ():
4275 async def source ():
4376 yield 1
4477 yield 2
45- yield 3
78+ yield 3 # pragma: no cover
4679
4780 doubles = MapAsyncIterator (source (), lambda x : x + x )
4881
@@ -58,13 +91,41 @@ async def source():
5891 with raises (StopAsyncIteration ):
5992 await anext (doubles )
6093
94+ @mark .asyncio
95+ async def allows_returning_early_from_mapped_async_iterator ():
96+ items = [1 , 2 , 3 ]
97+
98+ class Iterator :
99+ def __aiter__ (self ):
100+ return self
101+
102+ async def __anext__ (self ):
103+ try :
104+ return items .pop (0 )
105+ except IndexError : # pragma: no cover
106+ raise StopAsyncIteration
107+
108+ doubles = MapAsyncIterator (Iterator (), lambda x : x + x )
109+
110+ assert await anext (doubles ) == 2
111+ assert await anext (doubles ) == 4
112+
113+ # Early return
114+ await doubles .aclose ()
115+
116+ # Subsequent next calls
117+ with raises (StopAsyncIteration ):
118+ await anext (doubles )
119+ with raises (StopAsyncIteration ):
120+ await anext (doubles )
121+
61122 @mark .asyncio
62123 async def passes_through_early_return_from_async_values ():
63124 async def source ():
64125 try :
65126 yield 1
66127 yield 2
67- yield 3
128+ yield 3 # pragma: no cover
68129 finally :
69130 yield "Done"
70131 yield "Last"
@@ -83,13 +144,20 @@ async def source():
83144 assert await anext (doubles )
84145
85146 @mark .asyncio
86- async def allows_throwing_errors_through_async_generators ():
87- async def source ():
88- yield 1
89- yield 2
90- yield 3
147+ async def allows_throwing_errors_through_async_iterators ():
148+ items = [1 , 2 , 3 ]
91149
92- doubles = MapAsyncIterator (source (), lambda x : x + x )
150+ class Iterator :
151+ def __aiter__ (self ):
152+ return self
153+
154+ async def __anext__ (self ):
155+ try :
156+ return items .pop (0 )
157+ except IndexError : # pragma: no cover
158+ raise StopAsyncIteration
159+
160+ doubles = MapAsyncIterator (Iterator (), lambda x : x + x )
93161
94162 assert await anext (doubles ) == 2
95163 assert await anext (doubles ) == 4
@@ -111,7 +179,7 @@ async def source():
111179 try :
112180 yield 1
113181 yield 2
114- yield 3
182+ yield 3 # pragma: no cover
115183 except Exception as e :
116184 yield e
117185
@@ -249,8 +317,8 @@ async def stops_async_iteration_on_close():
249317 async def source ():
250318 yield 1
251319 await Event ().wait () # Block forever
252- yield 2
253- yield 3
320+ yield 2 # pragma: no cover
321+ yield 3 # pragma: no cover
254322
255323 singles = source ()
256324 doubles = MapAsyncIterator (singles , lambda x : x * 2 )
@@ -271,3 +339,53 @@ async def source():
271339
272340 with raises (StopAsyncIteration ):
273341 await anext (singles )
342+
343+ @mark .asyncio
344+ async def can_unset_closed_state_of_async_iterator ():
345+ items = [1 , 2 , 3 ]
346+
347+ class Iterator :
348+ def __init__ (self ):
349+ self .is_closed = False
350+
351+ def __aiter__ (self ):
352+ return self
353+
354+ async def __anext__ (self ):
355+ if self .is_closed :
356+ raise StopAsyncIteration
357+ try :
358+ return items .pop (0 )
359+ except IndexError :
360+ raise StopAsyncIteration
361+
362+ async def aclose (self ):
363+ self .is_closed = True
364+
365+ iterator = Iterator ()
366+ doubles = MapAsyncIterator (iterator , lambda x : x + x )
367+
368+ assert await anext (doubles ) == 2
369+ assert await anext (doubles ) == 4
370+ assert not iterator .is_closed
371+ await doubles .aclose ()
372+ assert iterator .is_closed
373+ with raises (StopAsyncIteration ):
374+ await anext (iterator )
375+ with raises (StopAsyncIteration ):
376+ await anext (doubles )
377+ assert doubles .is_closed
378+
379+ iterator .is_closed = False
380+ doubles .is_closed = False
381+ assert not doubles .is_closed
382+
383+ assert await anext (doubles ) == 6
384+ assert not doubles .is_closed
385+ assert not iterator .is_closed
386+ with raises (StopAsyncIteration ):
387+ await anext (iterator )
388+ with raises (StopAsyncIteration ):
389+ await anext (doubles )
390+ assert not doubles .is_closed
391+ assert not iterator .is_closed
0 commit comments