@@ -293,138 +293,177 @@ async def task_sem() -> int:
293293 assert sem_num == max_async_tasks + 2
294294
295295
296- @pytest .mark .anyio
297- async def test_tasks_chain_without_idler () -> None :
298- """"""
299- broker = InMemoryQueueBroker ()
300-
301- @broker .task
302- async def task_add_one (val : int ) -> int :
303- return val + 1
304-
305- @broker .task
306- async def task_map (vals : List [int ]) -> List [int ]:
307- tasks = [await task_add_one .kiq (val ) for val in vals ]
308- resps_tasks = [asyncio .create_task (t .wait_result (timeout = 1 )) for t in tasks ]
309- resps = await asyncio .gather (* resps_tasks )
310-
311- return [r .return_value for r in resps ]
312-
313- receiver = get_receiver (broker , max_async_tasks = 1 )
314- listen_task = asyncio .create_task (receiver .listen ())
315-
316- task = await task_map .kiq (list (range (0 , 10 )))
317- with pytest .raises (TaskiqResultTimeoutError ):
318- await task .wait_result (timeout = 1 )
319-
320- await broker .shutdown ()
321- await listen_task
322-
323-
324- @pytest .mark .anyio
325- async def test_tasks_chain_with_idler () -> None :
326- """"""
327- broker = InMemoryQueueBroker ()
328-
329- @broker .task
330- async def task_add_one (val : int ) -> int :
331- return val + 1
332-
333- @broker .task
334- async def task_map (vals : List [int ], ctx : Context = Depends ()) -> List [int ]:
335- tasks = [await task_add_one .kiq (val ) for val in vals ]
336- await ctx .sleep (0.5 )
337- resps_tasks = [asyncio .create_task (t .wait_result (timeout = 1 )) for t in tasks ]
338- resps = await asyncio .gather (* resps_tasks )
339- res = [r .return_value for r in resps ]
340- return res
341-
342- receiver = get_receiver (broker , max_async_tasks = 1 , max_sleeping_tasks = 1 )
343- listen_task = asyncio .create_task (receiver .listen ())
344-
345- task = await task_map .kiq (list (range (0 , 10 )))
346- resp = await task .wait_result (timeout = 1 )
347- assert resp .return_value == list (range (1 , 11 ))
348-
349- await broker .shutdown ()
350- await listen_task
351-
352- assert receiver .sem_sleeping ._value == 1 # type: ignore
353- assert receiver .sem ._value == 1 # type: ignore
354-
355-
356- @pytest .mark .anyio
357- async def test_tasks_chain_deep () -> None :
358- """"""
359- broker = InMemoryQueueBroker ()
360-
361- @broker .task
362- async def task_run (depth : int , val : Any , ctx : Context = Depends ()) -> Any :
363- if depth == 0 :
364- return val
365-
366- t = await task_run .kiq (depth - 1 , val )
367- resp = await wait_for_task (t , interval = 0.05 , ctx = ctx )
368- return resp .return_value
369-
370- async def wait_for_task (
371- task : AsyncTaskiqTask [Any ],
372- interval : float ,
373- ctx : Context ,
374- ) -> TaskiqResult [Any ]:
375- while True :
376- resp_task = asyncio .create_task (
377- task .wait_result (interval * 0.4 , timeout = interval ),
378- )
379- await ctx .sleep (interval )
380-
381- try :
382- return await resp_task
383- except TaskiqResultTimeoutError :
384- continue
385-
386- receiver = get_receiver (broker , max_async_tasks = 1 , max_sleeping_tasks = 10 )
387- listen_task = asyncio .create_task (receiver .listen ())
388-
389- task = await task_run .kiq (10 , "hello world!" )
390- resp = await task .wait_result (timeout = 1 )
391- assert resp .return_value == "hello world!"
392-
393- await broker .shutdown ()
394- await listen_task
395-
396- assert receiver .sem_sleeping ._value == 10 # type: ignore
397- assert receiver .sem ._value == 1 # type: ignore
398-
399-
400- @pytest .mark .anyio
401- async def test_tasks_sleep () -> None :
402- """"""
403- broker = InMemoryQueueBroker ()
404-
405- @broker .task
406- async def task_run (ind : int , ctx : Context = Depends ()) -> int :
407- await ctx .sleep (0.1 )
408- return ind
409-
410- receiver = get_receiver (broker , max_async_tasks = 1 , max_sleeping_tasks = 20 )
411- listen_task = asyncio .create_task (receiver .listen ())
412-
413- with anyio .fail_after (1 ):
414- tasks_tasks = [asyncio .create_task (task_run .kiq (ind )) for ind in range (100 )]
415- tasks = await asyncio .gather (* tasks_tasks )
416- resps_tasks = [
417- asyncio .create_task (task .wait_result (timeout = 1 )) for task in tasks
418- ]
419- resps = await asyncio .gather (* resps_tasks )
420- value = [resp .return_value for resp in resps ]
421- assert value == list (range (100 ))
422-
423- await broker .shutdown ()
424- await listen_task
425-
426- assert receiver .sem_sleeping ._value == 20 # type: ignore
427- assert receiver .sem ._value == 1 # type: ignore
296+ class Test_sleeping_tasks :
297+ @pytest .mark .anyio
298+ async def test_max_sleeping_task_arg_error (self ) -> None :
299+ with pytest .raises (ValueError ):
300+ get_receiver (max_sleeping_tasks = - 1 )
301+
302+ @pytest .mark .anyio
303+ async def test_tasks_chain_without_nonblocking_sleep (self ) -> None :
304+ """"""
305+ broker = InMemoryQueueBroker ()
306+
307+ @broker .task
308+ async def task_add_one (val : int ) -> int :
309+ return val + 1
310+
311+ @broker .task
312+ async def task_map (vals : List [int ]) -> List [int ]:
313+ tasks = [await task_add_one .kiq (val ) for val in vals ]
314+ resps_tasks = [asyncio .create_task (t .wait_result (timeout = 1 )) for t in tasks ]
315+ resps = await asyncio .gather (* resps_tasks )
316+
317+ return [r .return_value for r in resps ]
318+
319+ receiver = get_receiver (broker , max_async_tasks = 1 )
320+ listen_task = asyncio .create_task (receiver .listen ())
321+
322+ task = await task_map .kiq (list (range (0 , 10 )))
323+ with pytest .raises (TaskiqResultTimeoutError ):
324+ await task .wait_result (timeout = 1 )
325+
326+ await broker .shutdown ()
327+ await listen_task
328+
329+ @pytest .mark .anyio
330+ async def test_tasks_chain_with_nonblocking_sleep (self ) -> None :
331+ """"""
332+ broker = InMemoryQueueBroker ()
333+
334+ @broker .task
335+ async def task_add_one (val : int ) -> int :
336+ return val + 1
337+
338+ @broker .task
339+ async def task_map (vals : List [int ], ctx : Context = Depends ()) -> List [int ]:
340+ tasks = [await task_add_one .kiq (val ) for val in vals ]
341+ await ctx .sleep (0.5 )
342+ resps_tasks = [asyncio .create_task (t .wait_result (timeout = 1 )) for t in tasks ]
343+ resps = await asyncio .gather (* resps_tasks )
344+ res = [r .return_value for r in resps ]
345+ return res
346+
347+ receiver = get_receiver (broker , max_async_tasks = 1 , max_sleeping_tasks = 1 )
348+ listen_task = asyncio .create_task (receiver .listen ())
349+
350+ task = await task_map .kiq (list (range (0 , 10 )))
351+ resp = await task .wait_result (timeout = 1 )
352+ assert resp .return_value == list (range (1 , 11 ))
353+
354+ await broker .shutdown ()
355+ await listen_task
356+
357+ assert receiver .sem_sleeping ._value == 1 # type: ignore
358+ assert receiver .sem ._value == 1 # type: ignore
359+
360+ @pytest .mark .anyio
361+ async def test_tasks_long_chain (self ) -> None :
362+ """"""
363+ broker = InMemoryQueueBroker ()
364+
365+ @broker .task
366+ async def task_run (depth : int , val : Any , ctx : Context = Depends ()) -> Any :
367+ if depth == 0 :
368+ return val
369+
370+ t = await task_run .kiq (depth - 1 , val )
371+ resp = await wait_for_task (t , interval = 0.05 , ctx = ctx )
372+ return resp .return_value
373+
374+ async def wait_for_task (
375+ task : AsyncTaskiqTask [Any ],
376+ interval : float ,
377+ ctx : Context ,
378+ ) -> TaskiqResult [Any ]:
379+ while True :
380+ resp_task = asyncio .create_task (
381+ task .wait_result (interval * 0.4 , timeout = interval ),
382+ )
383+ await ctx .sleep (interval )
384+
385+ try :
386+ return await resp_task
387+ except TaskiqResultTimeoutError :
388+ continue
389+
390+ receiver = get_receiver (broker , max_async_tasks = 1 , max_sleeping_tasks = 10 )
391+ listen_task = asyncio .create_task (receiver .listen ())
392+
393+ task = await task_run .kiq (10 , "hello world!" )
394+ resp = await task .wait_result (timeout = 1 )
395+ assert resp .return_value == "hello world!"
396+
397+ await broker .shutdown ()
398+ await listen_task
399+
400+ assert receiver .sem_sleeping ._value == 10 # type: ignore
401+ assert receiver .sem ._value == 1 # type: ignore
402+
403+ @pytest .mark .parametrize (
404+ ("max_async_tasks" , "max_sleeping_tasks" ),
405+ [(1 , 20 ), (None , None ), (None , 20 ), (0 , None ), (0 , 20 ), (0 , 0 )],
406+ )
407+ @pytest .mark .anyio
408+ async def test_tasks_sleep (
409+ self ,
410+ max_async_tasks : Any ,
411+ max_sleeping_tasks : Any ,
412+ ) -> None :
413+ """"""
414+ broker = InMemoryQueueBroker ()
415+
416+ @broker .task
417+ async def task_run (ind : int , ctx : Context = Depends ()) -> int :
418+ await ctx .sleep (0.1 )
419+ return ind
420+
421+ receiver = get_receiver (
422+ broker ,
423+ max_async_tasks = max_async_tasks ,
424+ max_sleeping_tasks = max_sleeping_tasks ,
425+ )
426+ listen_task = asyncio .create_task (receiver .listen ())
427+
428+ with anyio .fail_after (1 ):
429+ tasks_tasks = [asyncio .create_task (task_run .kiq (ind )) for ind in range (100 )]
430+ tasks = await asyncio .gather (* tasks_tasks )
431+ resps_tasks = [
432+ asyncio .create_task (task .wait_result (timeout = 1 )) for task in tasks
433+ ]
434+ resps = await asyncio .gather (* resps_tasks )
435+ value = [resp .return_value for resp in resps ]
436+ assert value == list (range (100 ))
437+
438+ await broker .shutdown ()
439+ await listen_task
440+
441+ if max_sleeping_tasks is not None and max_sleeping_tasks > 0 :
442+ assert receiver .sem_sleeping ._value == max_sleeping_tasks # type: ignore
443+
444+ if max_async_tasks is not None and max_async_tasks > 0 :
445+ assert receiver .sem ._value == 1 # type: ignore
446+
447+ @pytest .mark .anyio
448+ async def test_max_sleeping_task_arg_none (self ) -> None :
449+ """"""
450+ broker = InMemoryQueueBroker ()
451+
452+ @broker .task
453+ async def task_run (ind : int , ctx : Context = Depends ()) -> int :
454+ await ctx .sleep (0.1 )
455+ return ind
456+
457+ receiver = get_receiver (broker , max_async_tasks = 1 , max_sleeping_tasks = None )
458+ listen_task = asyncio .create_task (receiver .listen ()) # type: ignore
459+
460+ with pytest .raises (TaskiqResultTimeoutError ):
461+ tasks_tasks = [asyncio .create_task (task_run .kiq (ind )) for ind in range (100 )]
462+ tasks = await asyncio .gather (* tasks_tasks )
463+ resps_tasks = [
464+ asyncio .create_task (task .wait_result (timeout = 1 )) for task in tasks
465+ ]
466+ await asyncio .gather (* resps_tasks )
428467
429468
430469@pytest .mark .anyio
0 commit comments