2424import neo4j
2525
2626from ... import env
27+ from ..._async_compat import mark_async_test
2728
2829
2930# TODO: Python 3.9: when support gets dropped, remove this mark
@@ -44,7 +45,7 @@ def test_can_create_async_driver_outside_of_loop(uri, auth):
4445
4546 async def return_1 (tx : neo4j .AsyncManagedTransaction ) -> None :
4647 nonlocal counter , was_full
47- res = await tx .run ("RETURN 1 " )
48+ res = await tx .run ("UNWIND range(1, 10000) AS x RETURN x " )
4849
4950 counter += 1
5051 while not was_full and counter < pool_size :
@@ -86,3 +87,55 @@ async def run(driver_: neo4j.AsyncDriver):
8687 loop .run_until_complete (coro )
8788 finally :
8889 loop .close ()
90+
91+
92+ @mark_async_test
93+ async def test_cancel_driver_close (uri , auth ):
94+ class Signal :
95+ queried = False
96+ released = False
97+
98+ async def fill_pool (driver_ : neo4j .AsyncDriver , n = 10 ):
99+ signals = [Signal () for _ in range (n )]
100+ await asyncio .gather (
101+ * (handle_session (driver_ .session (), signals [i ]) for i in range (n )),
102+ handle_signals (signals ),
103+ return_exceptions = True ,
104+ )
105+
106+ async def handle_signals (signals ):
107+ while any (not signal .queried for signal in signals ):
108+ await asyncio .sleep (0.001 )
109+ await asyncio .sleep (0.1 )
110+ for signal in signals :
111+ signal .released = True
112+
113+ async def handle_session (session , signal ):
114+ async with session :
115+ await session .execute_read (work , signal )
116+
117+ async def work (tx : neo4j .AsyncManagedTransaction , signal : Signal ) -> None :
118+ res = await tx .run ("UNWIND range(1, 10000) AS x RETURN x" )
119+ signal .queried = True
120+ while not signal .released :
121+ await asyncio .sleep (0.001 )
122+ await res .consume ()
123+
124+ def connection_count (driver_ ):
125+ return sum (len (v ) for v in driver_ ._pool .connections .values ())
126+
127+ driver = neo4j .AsyncGraphDatabase .driver (uri , auth = auth )
128+ await fill_pool (driver )
129+ # sanity check, there should be some connections
130+ assert connection_count (driver ) >= 10
131+
132+ # start the close and give it some event loop iterations to kick off
133+ fut = asyncio .ensure_future (driver .close ())
134+ await asyncio .sleep (0 )
135+
136+ # cancel in the middle of closing connections
137+ fut .cancel ()
138+ # give the driver a chance to close connections forcefully
139+ await asyncio .sleep (0 )
140+ # driver should be marked as closed to not emmit a ResourceWarning later
141+ assert driver ._closed == True
0 commit comments