Skip to content

Commit f2d97b2

Browse files
authored
Fixing sync BlockingConnectionPool's disconnect method to follow the definition in ConnectionPoolInterface (#3802)
* Fixing sync BlockingConnectionPool's disconnect method to follow the definition of the interface * Fixing linter errors
1 parent 55c6713 commit f2d97b2

File tree

3 files changed

+72
-9
lines changed

3 files changed

+72
-9
lines changed

redis/connection.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2954,14 +2954,18 @@ def release(self, connection):
29542954
pass
29552955
self._locked = False
29562956

2957-
def disconnect(self):
2958-
"Disconnects all connections in the pool."
2957+
def disconnect(self, inuse_connections: bool = True):
2958+
"Disconnects either all connections in the pool or just the free connections."
29592959
self._checkpid()
29602960
try:
29612961
if self._in_maintenance:
29622962
self._lock.acquire()
29632963
self._locked = True
2964-
for connection in self._connections:
2964+
if inuse_connections:
2965+
connections = self._connections
2966+
else:
2967+
connections = self._get_free_connections()
2968+
for connection in connections:
29652969
connection.disconnect()
29662970
finally:
29672971
if self._locked:

tests/test_asyncio/test_connection_pool.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,20 @@ class DummyConnection(Connection):
9595

9696
def __init__(self, **kwargs):
9797
self.kwargs = kwargs
98+
self._connected = False
9899

99100
def repr_pieces(self):
100101
return [("id", id(self)), ("kwargs", self.kwargs)]
101102

102103
async def connect(self):
103-
pass
104+
self._connected = True
104105

105106
async def disconnect(self):
106-
pass
107+
self._connected = False
108+
109+
@property
110+
def is_connected(self):
111+
return self._connected
107112

108113
async def can_read_destructive(self, timeout: float = 0):
109114
return False
@@ -203,6 +208,20 @@ async def test_repr_contains_db_info_unix(self):
203208
expected = "path=/abc,db=1,client_name=test-client"
204209
assert expected in repr(pool)
205210

211+
async def test_pool_disconnect(self, master_host):
212+
connection_kwargs = {
213+
"host": master_host[0],
214+
"port": master_host[1],
215+
}
216+
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
217+
conn = await pool.get_connection()
218+
await pool.disconnect(inuse_connections=True)
219+
assert not conn.is_connected
220+
221+
await conn.connect()
222+
await pool.disconnect(inuse_connections=False)
223+
assert conn.is_connected
224+
206225

207226
class TestBlockingConnectionPool:
208227
@asynccontextmanager
@@ -231,17 +250,21 @@ async def test_connection_creation(self, master_host):
231250
assert isinstance(connection, DummyConnection)
232251
assert connection.kwargs == connection_kwargs
233252

234-
async def test_disconnect(self, master_host):
235-
"""A regression test for #1047"""
253+
async def test_pool_disconnect(self, master_host):
236254
connection_kwargs = {
237255
"foo": "bar",
238256
"biz": "baz",
239257
"host": master_host[0],
240258
"port": master_host[1],
241259
}
242260
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
243-
await pool.get_connection()
261+
conn = await pool.get_connection()
244262
await pool.disconnect()
263+
assert not conn.is_connected
264+
265+
await conn.connect()
266+
await pool.disconnect(inuse_connections=False)
267+
assert conn.is_connected
245268

246269
async def test_multiple_connections(self, master_host):
247270
connection_kwargs = {"host": master_host[0], "port": master_host[1]}

tests/test_connection_pool.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,13 @@ class DummyConnection:
2929
def __init__(self, **kwargs):
3030
self.kwargs = kwargs
3131
self.pid = os.getpid()
32+
self._sock = None
3233

3334
def connect(self):
34-
pass
35+
self._sock = mock.Mock()
36+
37+
def disconnect(self):
38+
self._sock = None
3539

3640
def can_read(self):
3741
return False
@@ -140,6 +144,21 @@ def test_repr_contains_db_info_unix(self):
140144
expected = "path=/abc,db=1,client_name=test-client"
141145
assert expected in repr(pool)
142146

147+
def test_pool_disconnect(self, master_host):
148+
connection_kwargs = {
149+
"host": master_host[0],
150+
"port": master_host[1],
151+
}
152+
pool = self.get_pool(connection_kwargs=connection_kwargs)
153+
154+
conn = pool.get_connection()
155+
pool.disconnect()
156+
assert not conn._sock
157+
158+
conn.connect()
159+
pool.disconnect(inuse_connections=False)
160+
assert conn._sock
161+
143162

144163
class TestBlockingConnectionPool:
145164
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
@@ -244,6 +263,23 @@ def test_initialise_pool_with_cache(self, master_host):
244263
)
245264
assert isinstance(pool.get_connection(), CacheProxyConnection)
246265

266+
def test_pool_disconnect(self, master_host):
267+
connection_kwargs = {
268+
"foo": "bar",
269+
"biz": "baz",
270+
"host": master_host[0],
271+
"port": master_host[1],
272+
}
273+
pool = self.get_pool(connection_kwargs=connection_kwargs)
274+
275+
conn = pool.get_connection()
276+
pool.disconnect()
277+
assert not conn._sock
278+
279+
conn.connect()
280+
pool.disconnect(inuse_connections=False)
281+
assert conn._sock
282+
247283

248284
class TestConnectionPoolURLParsing:
249285
def test_hostname(self):

0 commit comments

Comments
 (0)