2525
2626from ..work import FakeConnection
2727
28- from neo4j import READ_ACCESS
28+ from neo4j import (
29+ READ_ACCESS ,
30+ WRITE_ACCESS ,
31+ )
2932from neo4j .addressing import ResolvedAddress
3033from neo4j .conf import (
3134 PoolConfig ,
3538from neo4j .io import Neo4jPool
3639
3740
41+ ROUTER_ADDRESS = ResolvedAddress (("1.2.3.1" , 9001 ), host_name = "host" )
42+ READER_ADDRESS = ResolvedAddress (("1.2.3.1" , 9002 ), host_name = "host" )
43+ WRITER_ADDRESS = ResolvedAddress (("1.2.3.1" , 9003 ), host_name = "host" )
44+
45+
3846@pytest .fixture ()
3947def opener ():
40- def open_ (* _ , ** __ ):
48+ def open_ (addr , timeout ):
4149 connection = FakeConnection ()
50+ connection .addr = addr
51+ connection .timeout = timeout
4252 route_mock = Mock ()
4353 route_mock .return_value = [{
4454 "ttl" : 1000 ,
4555 "servers" : [
46- {"addresses" : ["1.2.3.1:9001" ], "role" : "ROUTE" },
47- {
48- "addresses" : ["1.2.3.10:9010" , "1.2.3.11:9011" ],
49- "role" : "READ"
50- },
51- {
52- "addresses" : ["1.2.3.20:9020" , "1.2.3.21:9021" ],
53- "role" : "WRITE"
54- },
56+ {"addresses" : [str (ROUTER_ADDRESS )], "role" : "ROUTE" },
57+ {"addresses" : [str (READER_ADDRESS )], "role" : "READ" },
58+ {"addresses" : [str (WRITER_ADDRESS )], "role" : "WRITE" },
5559 ],
5660 }]
5761 connection .attach_mock (route_mock , "route" )
@@ -65,8 +69,7 @@ def open_(*_, **__):
6569
6670
6771def test_acquires_new_routing_table_if_deleted (opener ):
68- address = ResolvedAddress (("1.2.3.1" , 9001 ), host_name = "host" )
69- pool = Neo4jPool (opener , PoolConfig (), WorkspaceConfig (), address )
72+ pool = Neo4jPool (opener , PoolConfig (), WorkspaceConfig (), ROUTER_ADDRESS )
7073 cx = pool .acquire (READ_ACCESS , 30 , "test_db" , None )
7174 pool .release (cx )
7275 assert pool .routing_tables .get ("test_db" )
@@ -79,8 +82,7 @@ def test_acquires_new_routing_table_if_deleted(opener):
7982
8083
8184def test_acquires_new_routing_table_if_stale (opener ):
82- address = ResolvedAddress (("1.2.3.1" , 9001 ), host_name = "host" )
83- pool = Neo4jPool (opener , PoolConfig (), WorkspaceConfig (), address )
85+ pool = Neo4jPool (opener , PoolConfig (), WorkspaceConfig (), ROUTER_ADDRESS )
8486 cx = pool .acquire (READ_ACCESS , 30 , "test_db" , None )
8587 pool .release (cx )
8688 assert pool .routing_tables .get ("test_db" )
@@ -94,8 +96,7 @@ def test_acquires_new_routing_table_if_stale(opener):
9496
9597
9698def test_removes_old_routing_table (opener ):
97- address = ResolvedAddress (("1.2.3.1" , 9001 ), host_name = "host" )
98- pool = Neo4jPool (opener , PoolConfig (), WorkspaceConfig (), address )
99+ pool = Neo4jPool (opener , PoolConfig (), WorkspaceConfig (), ROUTER_ADDRESS )
99100 cx = pool .acquire (READ_ACCESS , 30 , "test_db1" , None )
100101 pool .release (cx )
101102 assert pool .routing_tables .get ("test_db1" )
@@ -113,3 +114,50 @@ def test_removes_old_routing_table(opener):
113114 assert pool .routing_tables ["test_db1" ].last_updated_time > old_value
114115 assert "test_db2" not in pool .routing_tables
115116
117+
118+ @pytest .mark .parametrize ("type_" , ("r" , "w" ))
119+ def test_chooses_right_connection_type (opener , type_ ):
120+ pool = Neo4jPool (opener , PoolConfig (), WorkspaceConfig (), ROUTER_ADDRESS )
121+ cx1 = pool .acquire (READ_ACCESS if type_ == "r" else WRITE_ACCESS ,
122+ 30 , "test_db" , None )
123+ pool .release (cx1 )
124+ if type_ == "r" :
125+ assert cx1 .addr == READER_ADDRESS
126+ else :
127+ assert cx1 .addr == WRITER_ADDRESS
128+
129+
130+ def test_reuses_connection (opener ):
131+ pool = Neo4jPool (opener , PoolConfig (), WorkspaceConfig (), ROUTER_ADDRESS )
132+ cx1 = pool .acquire (READ_ACCESS , 30 , "test_db" , None )
133+ pool .release (cx1 )
134+ cx2 = pool .acquire (READ_ACCESS , 30 , "test_db" , None )
135+ assert cx1 is cx2
136+
137+
138+ @pytest .mark .parametrize ("break_on_close" , (True , False ))
139+ def test_closes_stale_connections (opener , break_on_close ):
140+ def break_connection ():
141+ pool .deactivate (cx1 .addr )
142+
143+ if cx_close_mock_side_effect :
144+ cx_close_mock_side_effect ()
145+
146+ pool = Neo4jPool (opener , PoolConfig (), WorkspaceConfig (), ROUTER_ADDRESS )
147+ cx1 = pool .acquire (READ_ACCESS , 30 , "test_db" , None )
148+ pool .release (cx1 )
149+ assert cx1 in pool .connections [cx1 .addr ]
150+ # simulate connection going stale (e.g. exceeding) and than breaking when
151+ # the pool tries to close the connection
152+ cx1 .stale .return_value = True
153+ cx_close_mock = cx1 .close
154+ if break_on_close :
155+ cx_close_mock_side_effect = cx_close_mock .side_effect
156+ cx_close_mock .side_effect = break_connection
157+ cx2 = pool .acquire (READ_ACCESS , 30 , "test_db" , None )
158+ pool .release (cx2 )
159+ assert cx1 .close .called_once ()
160+ assert cx2 is not cx1
161+ assert cx2 .addr == cx1 .addr
162+ assert cx1 not in pool .connections [cx1 .addr ]
163+ assert cx2 in pool .connections [cx2 .addr ]
0 commit comments