diff --git a/redis/connection.py b/redis/connection.py index 35e2bdf9ce..389529a1a7 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -2954,14 +2954,18 @@ def release(self, connection): pass self._locked = False - def disconnect(self): - "Disconnects all connections in the pool." + def disconnect(self, inuse_connections: bool = True): + "Disconnects either all connections in the pool or just the free connections." self._checkpid() try: if self._in_maintenance: self._lock.acquire() self._locked = True - for connection in self._connections: + if inuse_connections: + connections = self._connections + else: + connections = self._get_free_connections() + for connection in connections: connection.disconnect() finally: if self._locked: diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index cb3dac9604..e658c14188 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -95,15 +95,20 @@ class DummyConnection(Connection): def __init__(self, **kwargs): self.kwargs = kwargs + self._connected = False def repr_pieces(self): return [("id", id(self)), ("kwargs", self.kwargs)] async def connect(self): - pass + self._connected = True async def disconnect(self): - pass + self._connected = False + + @property + def is_connected(self): + return self._connected async def can_read_destructive(self, timeout: float = 0): return False @@ -203,6 +208,20 @@ async def test_repr_contains_db_info_unix(self): expected = "path=/abc,db=1,client_name=test-client" assert expected in repr(pool) + async def test_pool_disconnect(self, master_host): + connection_kwargs = { + "host": master_host[0], + "port": master_host[1], + } + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + conn = await pool.get_connection() + await pool.disconnect(inuse_connections=True) + assert not conn.is_connected + + await conn.connect() + await pool.disconnect(inuse_connections=False) + assert conn.is_connected + class TestBlockingConnectionPool: @asynccontextmanager @@ -231,8 +250,7 @@ async def test_connection_creation(self, master_host): assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs - async def test_disconnect(self, master_host): - """A regression test for #1047""" + async def test_pool_disconnect(self, master_host): connection_kwargs = { "foo": "bar", "biz": "baz", @@ -240,8 +258,13 @@ async def test_disconnect(self, master_host): "port": master_host[1], } async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - await pool.get_connection() + conn = await pool.get_connection() await pool.disconnect() + assert not conn.is_connected + + await conn.connect() + await pool.disconnect(inuse_connections=False) + assert conn.is_connected async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 2397f15600..7365c6ff13 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -29,9 +29,13 @@ class DummyConnection: def __init__(self, **kwargs): self.kwargs = kwargs self.pid = os.getpid() + self._sock = None def connect(self): - pass + self._sock = mock.Mock() + + def disconnect(self): + self._sock = None def can_read(self): return False @@ -140,6 +144,21 @@ def test_repr_contains_db_info_unix(self): expected = "path=/abc,db=1,client_name=test-client" assert expected in repr(pool) + def test_pool_disconnect(self, master_host): + connection_kwargs = { + "host": master_host[0], + "port": master_host[1], + } + pool = self.get_pool(connection_kwargs=connection_kwargs) + + conn = pool.get_connection() + pool.disconnect() + assert not conn._sock + + conn.connect() + pool.disconnect(inuse_connections=False) + assert conn._sock + class TestBlockingConnectionPool: def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): @@ -244,6 +263,23 @@ def test_initialise_pool_with_cache(self, master_host): ) assert isinstance(pool.get_connection(), CacheProxyConnection) + def test_pool_disconnect(self, master_host): + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "host": master_host[0], + "port": master_host[1], + } + pool = self.get_pool(connection_kwargs=connection_kwargs) + + conn = pool.get_connection() + pool.disconnect() + assert not conn._sock + + conn.connect() + pool.disconnect(inuse_connections=False) + assert conn._sock + class TestConnectionPoolURLParsing: def test_hostname(self):