Skip to content

Commit 8fc1981

Browse files
authored
Fix removing connection twice from pool. (#597)
When trying to close a stale connection the driver count realize that the connection is dead on trying to send GOODBYE. This would cause the connection to make sure that all connections to the same address would get removed from the pool as well. Since this removal only happens as a side effect of `connection.close()` and does not always happen, the driver would still try to remove the (now already removed) connection form the pool after closure. Fixes: `ValueError: deque.remove(x): x not in deque`
1 parent f2cb000 commit 8fc1981

File tree

3 files changed

+89
-30
lines changed

3 files changed

+89
-30
lines changed

neo4j/io/__init__.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535
]
3636

3737
import abc
38-
from collections import deque
38+
from collections import (
39+
defaultdict,
40+
deque,
41+
)
3942
from logging import getLogger
4043
from random import choice
4144
from select import select
@@ -618,7 +621,7 @@ def __init__(self, opener, pool_config, workspace_config):
618621
self.opener = opener
619622
self.pool_config = pool_config
620623
self.workspace_config = workspace_config
621-
self.connections = {}
624+
self.connections = defaultdict(deque)
622625
self.lock = RLock()
623626
self.cond = Condition(self.lock)
624627

@@ -640,35 +643,44 @@ def _acquire(self, address, timeout):
640643
timeout = self.workspace_config.connection_acquisition_timeout
641644

642645
with self.lock:
643-
try:
644-
connections = self.connections[address]
645-
except KeyError:
646-
connections = self.connections[address] = deque()
647-
648646
def time_remaining():
649647
t = timeout - (perf_counter() - t0)
650648
return t if t > 0 else 0
651649

652650
while True:
653651
# try to find a free connection in pool
654-
for connection in list(connections):
652+
for connection in list(self.connections.get(address, [])):
655653
if (connection.closed() or connection.defunct()
656654
or connection.stale()):
657655
# `close` is a noop on already closed connections.
658656
# This is to make sure that the connection is gracefully
659657
# closed, e.g. if it's just marked as `stale` but still
660658
# alive.
661659
connection.close()
662-
connections.remove(connection)
660+
try:
661+
self.connections.get(address, []).remove(connection)
662+
except ValueError:
663+
# If closure fails (e.g. because the server went
664+
# down), all connections to the same address will
665+
# be removed. Therefore, we silently ignore if the
666+
# connection isn't in the pool anymore.
667+
pass
663668
continue
664669
if not connection.in_use:
665670
connection.in_use = True
666671
return connection
667672
# all connections in pool are in-use
668-
infinite_pool_size = (self.pool_config.max_connection_pool_size < 0 or self.pool_config.max_connection_pool_size == float("inf"))
669-
can_create_new_connection = infinite_pool_size or len(connections) < self.pool_config.max_connection_pool_size
673+
connections = self.connections[address]
674+
max_pool_size = self.pool_config.max_connection_pool_size
675+
infinite_pool_size = (max_pool_size < 0
676+
or max_pool_size == float("inf"))
677+
can_create_new_connection = (
678+
infinite_pool_size
679+
or len(connections) < max_pool_size
680+
)
670681
if can_create_new_connection:
671-
timeout = min(self.pool_config.connection_timeout, time_remaining())
682+
timeout = min(self.pool_config.connection_timeout,
683+
time_remaining())
672684
try:
673685
connection = self.opener(address, timeout)
674686
except ServiceUnavailable:

tests/unit/io/test_neo4j_pool.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525

2626
from ..work import FakeConnection
2727

28-
from neo4j import READ_ACCESS
28+
from neo4j import (
29+
READ_ACCESS,
30+
WRITE_ACCESS,
31+
)
2932
from neo4j.addressing import ResolvedAddress
3033
from neo4j.conf import (
3134
PoolConfig,
@@ -35,23 +38,24 @@
3538
from neo4j.io import Neo4jPool
3639

3740

41+
ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
42+
READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host")
43+
WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host")
44+
45+
3846
@pytest.fixture()
3947
def opener():
40-
def open_(*_, **__):
48+
def open_(addr, timeout):
4149
connection = FakeConnection()
50+
connection.addr = addr
51+
connection.timeout = timeout
4252
route_mock = Mock()
4353
route_mock.return_value = [{
4454
"ttl": 1000,
4555
"servers": [
46-
{"addresses": ["1.2.3.1:9001"], "role": "ROUTE"},
47-
{
48-
"addresses": ["1.2.3.10:9010", "1.2.3.11:9011"],
49-
"role": "READ"
50-
},
51-
{
52-
"addresses": ["1.2.3.20:9020", "1.2.3.21:9021"],
53-
"role": "WRITE"
54-
},
56+
{"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"},
57+
{"addresses": [str(READER_ADDRESS)], "role": "READ"},
58+
{"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"},
5559
],
5660
}]
5761
connection.attach_mock(route_mock, "route")
@@ -65,8 +69,7 @@ def open_(*_, **__):
6569

6670

6771
def test_acquires_new_routing_table_if_deleted(opener):
68-
address = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
69-
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address)
72+
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
7073
cx = pool.acquire(READ_ACCESS, 30, "test_db", None)
7174
pool.release(cx)
7275
assert pool.routing_tables.get("test_db")
@@ -79,8 +82,7 @@ def test_acquires_new_routing_table_if_deleted(opener):
7982

8083

8184
def test_acquires_new_routing_table_if_stale(opener):
82-
address = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
83-
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address)
85+
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
8486
cx = pool.acquire(READ_ACCESS, 30, "test_db", None)
8587
pool.release(cx)
8688
assert pool.routing_tables.get("test_db")
@@ -94,8 +96,7 @@ def test_acquires_new_routing_table_if_stale(opener):
9496

9597

9698
def test_removes_old_routing_table(opener):
97-
address = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
98-
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), address)
99+
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
99100
cx = pool.acquire(READ_ACCESS, 30, "test_db1", None)
100101
pool.release(cx)
101102
assert pool.routing_tables.get("test_db1")
@@ -113,3 +114,50 @@ def test_removes_old_routing_table(opener):
113114
assert pool.routing_tables["test_db1"].last_updated_time > old_value
114115
assert "test_db2" not in pool.routing_tables
115116

117+
118+
@pytest.mark.parametrize("type_", ("r", "w"))
119+
def test_chooses_right_connection_type(opener, type_):
120+
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
121+
cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS,
122+
30, "test_db", None)
123+
pool.release(cx1)
124+
if type_ == "r":
125+
assert cx1.addr == READER_ADDRESS
126+
else:
127+
assert cx1.addr == WRITER_ADDRESS
128+
129+
130+
def test_reuses_connection(opener):
131+
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
132+
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
133+
pool.release(cx1)
134+
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
135+
assert cx1 is cx2
136+
137+
138+
@pytest.mark.parametrize("break_on_close", (True, False))
139+
def test_closes_stale_connections(opener, break_on_close):
140+
def break_connection():
141+
pool.deactivate(cx1.addr)
142+
143+
if cx_close_mock_side_effect:
144+
cx_close_mock_side_effect()
145+
146+
pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
147+
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
148+
pool.release(cx1)
149+
assert cx1 in pool.connections[cx1.addr]
150+
# simulate connection going stale (e.g. exceeding) and than breaking when
151+
# the pool tries to close the connection
152+
cx1.stale.return_value = True
153+
cx_close_mock = cx1.close
154+
if break_on_close:
155+
cx_close_mock_side_effect = cx_close_mock.side_effect
156+
cx_close_mock.side_effect = break_connection
157+
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
158+
pool.release(cx2)
159+
assert cx1.close.called_once()
160+
assert cx2 is not cx1
161+
assert cx2.addr == cx1.addr
162+
assert cx1 not in pool.connections[cx1.addr]
163+
assert cx2 in pool.connections[cx2.addr]

tests/unit/work/_fake_connection.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def callback():
8787
return parent.__getattr__(name)
8888

8989

90-
9190
@pytest.fixture
9291
def fake_connection():
9392
return FakeConnection()

0 commit comments

Comments
 (0)