11import asyncio
2+ from asyncio import (
3+ Event ,
4+ )
5+ from contextlib import contextmanager
6+ import functools
27import pytest
8+ import random
39
10+ from cancel_token import CancelToken , OperationCancelled
411from eth_utils import ValidationError
12+ from hypothesis import (
13+ example ,
14+ given ,
15+ strategies as st ,
16+ )
517
618from trinity .utils .datastructures import TaskQueue
719
20+ DEFAULT_TIMEOUT = 0.05
21+
822
9- async def wait (coro , timeout = 0.05 ):
23+ async def wait (coro , timeout = DEFAULT_TIMEOUT ):
1024 return await asyncio .wait_for (coro , timeout = timeout )
1125
1226
27+ @contextmanager
28+ def trap_operation_cancelled ():
29+ try :
30+ yield
31+ except OperationCancelled :
32+ pass
33+
34+
35+ def run_in_event_loop (async_func ):
36+ @functools .wraps (async_func )
37+ def wrapped (operations , queue_size , add_size , get_size , event_loop ):
38+ event_loop .run_until_complete (asyncio .ensure_future (
39+ async_func (operations , queue_size , add_size , get_size , event_loop ),
40+ loop = event_loop ,
41+ ))
42+ return wrapped
43+
44+
45+ @given (
46+ operations = st .lists (
47+ elements = st .tuples (st .integers (min_value = 0 , max_value = 5 ), st .booleans ()),
48+ min_size = 10 ,
49+ max_size = 30 ,
50+ ),
51+ queue_size = st .integers (min_value = 1 , max_value = 20 ),
52+ add_size = st .integers (min_value = 1 , max_value = 20 ),
53+ get_size = st .integers (min_value = 1 , max_value = 20 ),
54+ )
55+ @example (
56+ # try having two adders alternate a couple times quickly
57+ operations = [(0 , False ), (1 , False ), (0 , False ), (1 , True ), (2 , False ), (2 , False ), (2 , False )],
58+ queue_size = 5 ,
59+ add_size = 2 ,
60+ get_size = 5 ,
61+ )
62+ @run_in_event_loop
63+ async def test_no_asyncio_exception_leaks (operations , queue_size , add_size , get_size , event_loop ):
64+ """
65+ This could be made much more general, at the cost of simplicity.
66+ For now, this mimics real usage enough to hopefully catch the big issues.
67+
68+ Some examples for more generality:
69+
70+ - different get sizes on each call
71+ - complete varying amounts of tasks at each call
72+ """
73+
74+ async def getter (queue , num_tasks , get_event , complete_event , cancel_token ):
75+ with trap_operation_cancelled ():
76+ # wait to run the get
77+ await cancel_token .cancellable_wait (get_event .wait ())
78+
79+ batch , tasks = await cancel_token .cancellable_wait (
80+ queue .get (num_tasks )
81+ )
82+ get_event .clear ()
83+
84+ # wait to run the completion
85+ await cancel_token .cancellable_wait (complete_event .wait ())
86+
87+ queue .complete (batch , tasks )
88+ complete_event .clear ()
89+
90+ async def adder (queue , add_size , add_event , cancel_token ):
91+ with trap_operation_cancelled ():
92+ # wait to run the add
93+ await cancel_token .cancellable_wait (add_event .wait ())
94+
95+ await cancel_token .cancellable_wait (
96+ queue .add (tuple (random .randint (0 , 2 ** 32 ) for _ in range (add_size )))
97+ )
98+ add_event .clear ()
99+
100+ async def operation_order (operations , events , cancel_token ):
101+ for operation_id , pause in operations :
102+ events [operation_id ].set ()
103+ if pause :
104+ await asyncio .sleep (0 )
105+
106+ await asyncio .sleep (0 )
107+ cancel_token .trigger ()
108+
109+ q = TaskQueue (queue_size )
110+ events = tuple (Event () for _ in range (6 ))
111+ add_event , add2_event , get_event , get2_event , complete_event , complete2_event = events
112+ cancel_token = CancelToken ('end test' )
113+
114+ done , pending = await asyncio .wait ([
115+ getter (q , get_size , get_event , complete_event , cancel_token ),
116+ getter (q , get_size , get2_event , complete2_event , cancel_token ),
117+ adder (q , add_size , add_event , cancel_token ),
118+ adder (q , add_size , add2_event , cancel_token ),
119+ operation_order (operations , events , cancel_token ),
120+ ], return_when = asyncio .FIRST_EXCEPTION )
121+
122+ for task in done :
123+ exc = task .exception ()
124+ if exc :
125+ raise exc
126+
127+ assert not pending
128+
129+
13130@pytest .mark .asyncio
14131async def test_queue_size_reset_after_complete ():
15132 q = TaskQueue (maxsize = 2 )
@@ -63,7 +180,7 @@ async def test_default_priority_order():
63180
64181@pytest .mark .asyncio
65182async def test_custom_priority_order ():
66- q = TaskQueue (maxsize = 4 , order_fn = lambda x : 0 - x )
183+ q = TaskQueue (maxsize = 4 , order_fn = lambda x : 0 - x )
67184
68185 await wait (q .add ((2 , 1 , 3 )))
69186 (batch , tasks ) = await wait (q .get ())
@@ -108,6 +225,25 @@ async def test_wait_empty_queue():
108225 assert False , "should not return from get() when nothing is available on queue"
109226
110227
228+ @pytest .mark .asyncio
229+ async def test_cannot_complete_batch_with_wrong_task ():
230+ q = TaskQueue ()
231+
232+ await wait (q .add ((1 , 2 )))
233+
234+ batch , tasks = await wait (q .get ())
235+
236+ # cannot complete a valid task with a task it wasn't given
237+ with pytest .raises (ValidationError ):
238+ q .complete (batch , (3 , 4 ))
239+
240+ # partially invalid completion calls leave the valid task in an incomplete state
241+ with pytest .raises (ValidationError ):
242+ q .complete (batch , (1 , 3 ))
243+
244+ assert 1 in q
245+
246+
111247@pytest .mark .asyncio
112248async def test_cannot_complete_batch_unless_pending ():
113249 q = TaskQueue ()
@@ -156,10 +292,9 @@ async def test_two_pending_adds_one_release():
156292 assert len (tasks ) in {0 , 1 }
157293
158294 if len (tasks ) == 1 :
159- batch2 , tasks2 = await wait (q .get ())
295+ _ , tasks2 = await wait (q .get ())
160296 all_tasks = tuple (sorted (tasks + tasks2 ))
161297 elif len (tasks ) == 2 :
162- batch2 = None
163298 all_tasks = tasks
164299
165300 assert all_tasks == (0 , 3 )
@@ -186,12 +321,20 @@ async def test_queue_get_cap(start_tasks, get_max, expected, remainder):
186321 assert tasks == expected
187322
188323 if remainder :
189- batch2 , tasks2 = await wait (q .get ())
324+ _ , tasks2 = await wait (q .get ())
190325 assert tasks2 == remainder
191326 else :
192327 try :
193- batch2 , tasks2 = await wait (q .get ())
328+ _ , tasks2 = await wait (q .get ())
194329 except asyncio .TimeoutError :
195330 pass
196331 else :
197332 assert False , f"No more tasks to get, but got { tasks2 !r} "
333+
334+
335+ @pytest .mark .asyncio
336+ async def test_cannot_readd_same_task ():
337+ q = TaskQueue ()
338+ await q .add ((1 , 2 ))
339+ with pytest .raises (ValidationError ):
340+ await q .add ((2 ,))
0 commit comments