@@ -12,104 +12,319 @@ async def anext(iterator):
1212 return await iterator .__anext__ ()
1313
1414
15- async def map_doubles (x : int ) -> int :
15+ async def double (x : int ) -> int :
16+ """Test callback that doubles the input value."""
1617 return x + x
1718
1819
20+ async def throw (_x : int ) -> int :
21+ """Test callback that raises a RuntimeError."""
22+ raise RuntimeError ("Ouch" )
23+
24+
1925def describe_map_async_iterable ():
2026 @mark .asyncio
21- async def inner_is_closed_when_outer_is_closed ():
22- class Inner :
23- def __init__ (self ):
24- self .closed = False
27+ async def maps_over_async_generator ():
28+ async def source ():
29+ yield 1
30+ yield 2
31+ yield 3
2532
26- async def aclose (self ):
27- self .closed = True
33+ doubles = map_async_iterable (source (), double )
34+
35+ assert await anext (doubles ) == 2
36+ assert await anext (doubles ) == 4
37+ assert await anext (doubles ) == 6
38+ with raises (StopAsyncIteration ):
39+ assert await anext (doubles )
40+
41+ @mark .asyncio
42+ async def maps_over_async_iterable ():
43+ items = [1 , 2 , 3 ]
44+
45+ class Iterable :
46+ def __aiter__ (self ):
47+ return self
48+
49+ async def __anext__ (self ):
50+ try :
51+ return items .pop (0 )
52+ except IndexError :
53+ raise StopAsyncIteration
54+
55+ doubles = map_async_iterable (Iterable (), double )
56+
57+ values = [value async for value in doubles ]
58+
59+ assert not items
60+ assert values == [2 , 4 , 6 ]
61+
62+ @mark .asyncio
63+ async def compatible_with_async_for ():
64+ async def source ():
65+ yield 1
66+ yield 2
67+ yield 3
68+
69+ doubles = map_async_iterable (source (), double )
70+
71+ values = [value async for value in doubles ]
72+
73+ assert values == [2 , 4 , 6 ]
74+
75+ @mark .asyncio
76+ async def allows_returning_early_from_mapped_async_generator ():
77+ async def source ():
78+ yield 1
79+ yield 2
80+ yield 3 # pragma: no cover
2881
82+ doubles = map_async_iterable (source (), double )
83+
84+ assert await anext (doubles ) == 2
85+ assert await anext (doubles ) == 4
86+
87+ # Early return
88+ await doubles .aclose ()
89+
90+ # Subsequent next calls
91+ with raises (StopAsyncIteration ):
92+ await anext (doubles )
93+ with raises (StopAsyncIteration ):
94+ await anext (doubles )
95+
96+ @mark .asyncio
97+ async def allows_returning_early_from_mapped_async_iterable ():
98+ items = [1 , 2 , 3 ]
99+
100+ class Iterable :
101+ def __aiter__ (self ):
102+ return self
103+
104+ async def __anext__ (self ):
105+ try :
106+ return items .pop (0 )
107+ except IndexError : # pragma: no cover
108+ raise StopAsyncIteration
109+
110+ doubles = map_async_iterable (Iterable (), double )
111+
112+ assert await anext (doubles ) == 2
113+ assert await anext (doubles ) == 4
114+
115+ # Early return
116+ await doubles .aclose ()
117+
118+ # Subsequent next calls
119+ with raises (StopAsyncIteration ):
120+ await anext (doubles )
121+ with raises (StopAsyncIteration ):
122+ await anext (doubles )
123+
124+ @mark .asyncio
125+ async def allows_throwing_errors_through_async_iterable ():
126+ items = [1 , 2 , 3 ]
127+
128+ class Iterable :
129+ def __aiter__ (self ):
130+ return self
131+
132+ async def __anext__ (self ):
133+ try :
134+ return items .pop (0 )
135+ except IndexError : # pragma: no cover
136+ raise StopAsyncIteration
137+
138+ doubles = map_async_iterable (Iterable (), double )
139+
140+ assert await anext (doubles ) == 2
141+ assert await anext (doubles ) == 4
142+
143+ # Throw error
144+ message = "allows throwing errors when mapping async iterable"
145+ with raises (RuntimeError ) as exc_info :
146+ await doubles .athrow (RuntimeError (message ))
147+
148+ assert str (exc_info .value ) == message
149+
150+ with raises (StopAsyncIteration ):
151+ await anext (doubles )
152+ with raises (StopAsyncIteration ):
153+ await anext (doubles )
154+
155+ @mark .asyncio
156+ async def allows_throwing_errors_with_values_through_async_iterables ():
157+ class Iterable :
158+ def __aiter__ (self ):
159+ return self
160+
161+ async def __anext__ (self ):
162+ return 1
163+
164+ one = map_async_iterable (Iterable (), double )
165+
166+ assert await anext (one ) == 2
167+
168+ # Throw error with value passed separately
169+ try :
170+ raise RuntimeError ("Ouch" )
171+ except RuntimeError as error :
172+ with raises (RuntimeError , match = "Ouch" ) as exc_info :
173+ await one .athrow (error .__class__ , error )
174+
175+ assert exc_info .value is error
176+ assert exc_info .tb is error .__traceback__
177+
178+ with raises (StopAsyncIteration ):
179+ await anext (one )
180+
181+ @mark .asyncio
182+ async def allows_throwing_errors_with_traceback_through_async_iterables ():
183+ class Iterable :
29184 def __aiter__ (self ):
30185 return self
31186
32187 async def __anext__ (self ):
33188 return 1
34189
35- inner = Inner ()
36- outer = map_async_iterable (inner , map_doubles )
37- iterator = outer .__aiter__ ()
38- assert await anext (iterator ) == 2
39- assert not inner .closed
40- await outer .aclose ()
41- assert inner .closed
190+ one = map_async_iterable (Iterable (), double )
191+
192+ assert await anext (one ) == 2
193+
194+ # Throw error with traceback passed separately
195+ try :
196+ raise RuntimeError ("Ouch" )
197+ except RuntimeError as error :
198+ with raises (RuntimeError ) as exc_info :
199+ await one .athrow (error .__class__ , None , error .__traceback__ )
200+
201+ assert exc_info .tb and error .__traceback__
202+ assert exc_info .tb .tb_frame is error .__traceback__ .tb_frame
203+
204+ with raises (StopAsyncIteration ):
205+ await anext (one )
42206
43207 @mark .asyncio
44- async def inner_is_closed_on_callback_error ():
45- class Inner :
208+ async def does_not_map_over_thrown_errors ():
209+ async def source ():
210+ yield 1
211+ raise RuntimeError ("Goodbye" )
212+
213+ doubles = map_async_iterable (source (), double )
214+
215+ assert await anext (doubles ) == 2
216+
217+ with raises (RuntimeError ) as exc_info :
218+ await anext (doubles )
219+
220+ assert str (exc_info .value ) == "Goodbye"
221+
222+ @mark .asyncio
223+ async def does_not_map_over_externally_thrown_errors ():
224+ async def source ():
225+ yield 1
226+
227+ doubles = map_async_iterable (source (), double )
228+
229+ assert await anext (doubles ) == 2
230+
231+ with raises (RuntimeError ) as exc_info :
232+ await doubles .athrow (RuntimeError ("Goodbye" ))
233+
234+ assert str (exc_info .value ) == "Goodbye"
235+
236+ @mark .asyncio
237+ async def iterable_is_closed_when_mapped_iterable_is_closed ():
238+ class Iterable :
46239 def __init__ (self ):
47240 self .closed = False
48241
242+ def __aiter__ (self ):
243+ return self
244+
245+ async def __anext__ (self ):
246+ return 1
247+
49248 async def aclose (self ):
50249 self .closed = True
51250
251+ iterable = Iterable ()
252+ doubles = map_async_iterable (iterable , double )
253+ assert await anext (doubles ) == 2
254+ assert not iterable .closed
255+ await doubles .aclose ()
256+ assert iterable .closed
257+ with raises (StopAsyncIteration ):
258+ await anext (doubles )
259+
260+ @mark .asyncio
261+ async def iterable_is_closed_on_callback_error ():
262+ class Iterable :
263+ def __init__ (self ):
264+ self .closed = False
265+
52266 def __aiter__ (self ):
53267 return self
54268
55269 async def __anext__ (self ):
56270 return 1
57271
58- async def callback ( v ):
59- raise RuntimeError ()
272+ async def aclose ( self ):
273+ self . closed = True
60274
61- inner = Inner ()
62- outer = map_async_iterable (inner , callback )
63- with raises (RuntimeError ):
64- await anext (outer )
65- assert inner .closed
275+ iterable = Iterable ()
276+ doubles = map_async_iterable (iterable , throw )
277+ with raises (RuntimeError , match = "Ouch" ):
278+ await anext (doubles )
279+ assert iterable .closed
280+ with raises (StopAsyncIteration ):
281+ await anext (doubles )
66282
67283 @mark .asyncio
68- async def test_inner_exits_on_callback_error ():
69- inner_exit = False
284+ async def iterable_exits_on_callback_error ():
285+ exited = False
70286
71- async def inner ():
72- nonlocal inner_exit
287+ async def iterable ():
288+ nonlocal exited
73289 try :
74290 while True :
75291 yield 1
76292 except GeneratorExit :
77- inner_exit = True
293+ exited = True
78294
79- async def callback (v ):
80- raise RuntimeError
81-
82- outer = map_async_iterable (inner (), callback )
83- with raises (RuntimeError ):
84- await anext (outer )
85- assert inner_exit
295+ doubles = map_async_iterable (iterable (), throw )
296+ with raises (RuntimeError , match = "Ouch" ):
297+ await anext (doubles )
298+ assert exited
299+ with raises (StopAsyncIteration ):
300+ await anext (doubles )
86301
87302 @mark .asyncio
88- async def inner_has_no_close_method_when_outer_is_closed ():
89- class Inner :
303+ async def mapped_iterable_is_closed_when_iterable_cannot_be_closed ():
304+ class Iterable :
90305 def __aiter__ (self ):
91306 return self
92307
93308 async def __anext__ (self ):
94309 return 1
95310
96- outer = map_async_iterable (Inner (), map_doubles )
97- iterator = outer .__aiter__ ()
98- assert await anext (iterator ) == 2
99- await outer .aclose ()
311+ doubles = map_async_iterable (Iterable (), double )
312+ assert await anext (doubles ) == 2
313+ await doubles .aclose ()
314+ with raises (StopAsyncIteration ):
315+ await anext (doubles )
100316
101317 @mark .asyncio
102- async def inner_has_no_close_method_on_callback_error ():
103- class Inner :
318+ async def ignores_that_iterable_cannot_be_closed_on_callback_error ():
319+ class Iterable :
104320 def __aiter__ (self ):
105321 return self
106322
107323 async def __anext__ (self ):
108324 return 1
109325
110- async def callback (v ):
111- raise RuntimeError ()
112-
113- outer = map_async_iterable (Inner (), callback )
114- with raises (RuntimeError ):
115- await anext (outer )
326+ doubles = map_async_iterable (Iterable (), throw )
327+ with raises (RuntimeError , match = "Ouch" ):
328+ await anext (doubles )
329+ with raises (StopAsyncIteration ):
330+ await anext (doubles )
0 commit comments