Skip to content

Commit fe66963

Browse files
authored
Merge branch 'master' into DOC-5821-update-query-example
2 parents ca8bf52 + 68483c1 commit fe66963

File tree

8 files changed

+263
-10
lines changed

8 files changed

+263
-10
lines changed

redis/_parsers/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ClusterDownError,
2828
ConnectionError,
2929
ExecAbortError,
30+
ExternalAuthProviderError,
3031
MasterDownError,
3132
ModuleError,
3233
MovedError,
@@ -60,6 +61,10 @@
6061
"Client sent AUTH, but no password is set": AuthenticationError,
6162
}
6263

64+
EXTERNAL_AUTH_PROVIDER_ERROR = {
65+
"problem with LDAP service": ExternalAuthProviderError,
66+
}
67+
6368
logger = logging.getLogger(__name__)
6469

6570

@@ -81,6 +86,7 @@ class BaseParser(ABC):
8186
NO_SUCH_MODULE_ERROR: ModuleError,
8287
MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
8388
**NO_AUTH_SET_ERROR,
89+
**EXTERNAL_AUTH_PROVIDER_ERROR,
8490
},
8591
"OOM": OutOfMemoryError,
8692
"WRONGPASS": AuthenticationError,

redis/asyncio/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1356,7 +1356,7 @@ class BlockingConnectionPool(ConnectionPool):
13561356
def __init__(
13571357
self,
13581358
max_connections: int = 50,
1359-
timeout: Optional[int] = 20,
1359+
timeout: Optional[float] = 20,
13601360
connection_class: Type[AbstractConnection] = Connection,
13611361
queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated
13621362
**connection_kwargs,

redis/connection.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2954,14 +2954,18 @@ def release(self, connection):
29542954
pass
29552955
self._locked = False
29562956

2957-
def disconnect(self):
2958-
"Disconnects all connections in the pool."
2957+
def disconnect(self, inuse_connections: bool = True):
2958+
"Disconnects either all connections in the pool or just the free connections."
29592959
self._checkpid()
29602960
try:
29612961
if self._in_maintenance:
29622962
self._lock.acquire()
29632963
self._locked = True
2964-
for connection in self._connections:
2964+
if inuse_connections:
2965+
connections = self._connections
2966+
else:
2967+
connections = self._get_free_connections()
2968+
for connection in connections:
29652969
connection.disconnect()
29662970
finally:
29672971
if self._locked:

redis/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,11 @@ class InvalidPipelineStack(RedisClusterException):
245245
"""
246246

247247
pass
248+
249+
250+
class ExternalAuthProviderError(ConnectionError):
251+
"""
252+
Raised when an external authentication provider returns an error.
253+
"""
254+
255+
pass

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,15 @@ def skip_if_resp_version(resp_version) -> _TestDecorator:
334334
return pytest.mark.skipif(check, reason=f"RESP version required != {resp_version}")
335335

336336

337+
def skip_if_hiredis_parser() -> _TestDecorator:
338+
try:
339+
import hiredis # noqa
340+
341+
return pytest.mark.skipif(True, reason="hiredis dependency found")
342+
except ImportError:
343+
return pytest.mark.skipif(False, reason="No hiredis dependency")
344+
345+
337346
def _get_client(
338347
cls, request, single_connection_client=True, flushdb=True, from_url=None, **kwargs
339348
):

tests/test_asyncio/test_connection_pool.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,20 @@ class DummyConnection(Connection):
9595

9696
def __init__(self, **kwargs):
9797
self.kwargs = kwargs
98+
self._connected = False
9899

99100
def repr_pieces(self):
100101
return [("id", id(self)), ("kwargs", self.kwargs)]
101102

102103
async def connect(self):
103-
pass
104+
self._connected = True
104105

105106
async def disconnect(self):
106-
pass
107+
self._connected = False
108+
109+
@property
110+
def is_connected(self):
111+
return self._connected
107112

108113
async def can_read_destructive(self, timeout: float = 0):
109114
return False
@@ -203,6 +208,20 @@ async def test_repr_contains_db_info_unix(self):
203208
expected = "path=/abc,db=1,client_name=test-client"
204209
assert expected in repr(pool)
205210

211+
async def test_pool_disconnect(self, master_host):
212+
connection_kwargs = {
213+
"host": master_host[0],
214+
"port": master_host[1],
215+
}
216+
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
217+
conn = await pool.get_connection()
218+
await pool.disconnect(inuse_connections=True)
219+
assert not conn.is_connected
220+
221+
await conn.connect()
222+
await pool.disconnect(inuse_connections=False)
223+
assert conn.is_connected
224+
206225

207226
class TestBlockingConnectionPool:
208227
@asynccontextmanager
@@ -231,17 +250,21 @@ async def test_connection_creation(self, master_host):
231250
assert isinstance(connection, DummyConnection)
232251
assert connection.kwargs == connection_kwargs
233252

234-
async def test_disconnect(self, master_host):
235-
"""A regression test for #1047"""
253+
async def test_pool_disconnect(self, master_host):
236254
connection_kwargs = {
237255
"foo": "bar",
238256
"biz": "baz",
239257
"host": master_host[0],
240258
"port": master_host[1],
241259
}
242260
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
243-
await pool.get_connection()
261+
conn = await pool.get_connection()
244262
await pool.disconnect()
263+
assert not conn.is_connected
264+
265+
await conn.connect()
266+
await pool.disconnect(inuse_connections=False)
267+
assert conn.is_connected
245268

246269
async def test_multiple_connections(self, master_host):
247270
connection_kwargs = {"host": master_host[0], "port": master_host[1]}

tests/test_connection_pool.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,13 @@ class DummyConnection:
2929
def __init__(self, **kwargs):
3030
self.kwargs = kwargs
3131
self.pid = os.getpid()
32+
self._sock = None
3233

3334
def connect(self):
34-
pass
35+
self._sock = mock.Mock()
36+
37+
def disconnect(self):
38+
self._sock = None
3539

3640
def can_read(self):
3741
return False
@@ -140,6 +144,21 @@ def test_repr_contains_db_info_unix(self):
140144
expected = "path=/abc,db=1,client_name=test-client"
141145
assert expected in repr(pool)
142146

147+
def test_pool_disconnect(self, master_host):
148+
connection_kwargs = {
149+
"host": master_host[0],
150+
"port": master_host[1],
151+
}
152+
pool = self.get_pool(connection_kwargs=connection_kwargs)
153+
154+
conn = pool.get_connection()
155+
pool.disconnect()
156+
assert not conn._sock
157+
158+
conn.connect()
159+
pool.disconnect(inuse_connections=False)
160+
assert conn._sock
161+
143162

144163
class TestBlockingConnectionPool:
145164
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
@@ -244,6 +263,23 @@ def test_initialise_pool_with_cache(self, master_host):
244263
)
245264
assert isinstance(pool.get_connection(), CacheProxyConnection)
246265

266+
def test_pool_disconnect(self, master_host):
267+
connection_kwargs = {
268+
"foo": "bar",
269+
"biz": "baz",
270+
"host": master_host[0],
271+
"port": master_host[1],
272+
}
273+
pool = self.get_pool(connection_kwargs=connection_kwargs)
274+
275+
conn = pool.get_connection()
276+
pool.disconnect()
277+
assert not conn._sock
278+
279+
conn.connect()
280+
pool.disconnect(inuse_connections=False)
281+
assert conn._sock
282+
247283

248284
class TestConnectionPoolURLParsing:
249285
def test_hostname(self):

tests/test_parsers/test_errors.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import socket
2+
from unittest.mock import patch
3+
4+
import pytest
5+
from redis.client import Redis
6+
from redis.exceptions import ExternalAuthProviderError
7+
from tests.conftest import skip_if_hiredis_parser
8+
9+
10+
class MockSocket:
11+
"""Mock socket that simulates Redis protocol responses."""
12+
13+
def __init__(self):
14+
self.sent_data = []
15+
self.closed = False
16+
self.pending_responses = []
17+
18+
def connect(self, address):
19+
pass
20+
21+
def send(self, data):
22+
"""Simulate sending data to Redis."""
23+
if self.closed:
24+
raise ConnectionError("Socket is closed")
25+
self.sent_data.append(data)
26+
27+
# Analyze the command and prepare appropriate response
28+
if b"HELLO" in data:
29+
response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.4.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n"
30+
self.pending_responses.append(response)
31+
elif b"SET" in data:
32+
response = b"+OK\r\n"
33+
self.pending_responses.append(response)
34+
elif b"GET" in data:
35+
# Extract key and provide appropriate response
36+
if b"hello" in data:
37+
response = b"$5\r\nworld\r\n"
38+
self.pending_responses.append(response)
39+
# Handle specific keys used in tests
40+
elif b"ldap_error" in data:
41+
self.pending_responses.append(b"-ERR problem with LDAP service\r\n")
42+
else:
43+
self.pending_responses.append(b"$-1\r\n") # NULL response
44+
else:
45+
self.pending_responses.append(b"+OK\r\n") # Default response
46+
47+
return len(data)
48+
49+
def sendall(self, data):
50+
"""Simulate sending all data to Redis."""
51+
return self.send(data)
52+
53+
def recv(self, bufsize):
54+
"""Simulate receiving data from Redis."""
55+
if self.closed:
56+
raise ConnectionError("Socket is closed")
57+
58+
# Use pending responses that were prepared when commands were sent
59+
if self.pending_responses:
60+
response = self.pending_responses.pop(0)
61+
return response[:bufsize] # Respect buffer size
62+
else:
63+
# No data available - this should block or raise an exception
64+
# For can_read checks, we should indicate no data is available
65+
import errno
66+
67+
raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable")
68+
69+
def recv_into(self, buffer, nbytes=0):
70+
"""
71+
Receive data from Redis and write it into the provided buffer.
72+
Returns the number of bytes written.
73+
74+
This method is used by the hiredis parser for efficient data reading.
75+
"""
76+
if self.closed:
77+
raise ConnectionError("Socket is closed")
78+
79+
# Use pending responses that were prepared when commands were sent
80+
if self.pending_responses:
81+
response = self.pending_responses.pop(0)
82+
83+
# Determine how many bytes to write
84+
if nbytes == 0:
85+
nbytes = len(buffer)
86+
87+
# Write data into the buffer (up to nbytes or response length)
88+
bytes_to_write = min(len(response), nbytes, len(buffer))
89+
buffer[:bytes_to_write] = response[:bytes_to_write]
90+
91+
return bytes_to_write
92+
else:
93+
# No data available - this should block or raise an exception
94+
# For can_read checks, we should indicate no data is available
95+
import errno
96+
97+
raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable")
98+
99+
def fileno(self):
100+
"""Return a fake file descriptor for select/poll operations."""
101+
return 1 # Fake file descriptor
102+
103+
def close(self):
104+
"""Simulate closing the socket."""
105+
self.closed = True
106+
self.address = None
107+
self.timeout = None
108+
109+
def settimeout(self, timeout):
110+
pass
111+
112+
def setsockopt(self, level, optname, value):
113+
pass
114+
115+
def setblocking(self, blocking):
116+
pass
117+
118+
def shutdown(self, how):
119+
pass
120+
121+
122+
class TestErrorParsing:
123+
def setup_method(self):
124+
"""Set up test fixtures with mocked sockets."""
125+
self.mock_sockets = []
126+
self.original_socket = socket.socket
127+
128+
# Mock socket creation to return our mock sockets
129+
def mock_socket_factory(*args, **kwargs):
130+
mock_sock = MockSocket()
131+
self.mock_sockets.append(mock_sock)
132+
return mock_sock
133+
134+
self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory)
135+
self.socket_patcher.start()
136+
137+
# Mock select.select to simulate data availability for reading
138+
def mock_select(rlist, wlist, xlist, timeout=0):
139+
# Check if any of the sockets in rlist have data available
140+
ready_sockets = []
141+
for sock in rlist:
142+
if hasattr(sock, "connected") and sock.connected and not sock.closed:
143+
# Only return socket as ready if it actually has data to read
144+
if hasattr(sock, "pending_responses") and sock.pending_responses:
145+
ready_sockets.append(sock)
146+
# Don't return socket as ready just because it received commands
147+
# Only when there are actual responses available
148+
return (ready_sockets, [], [])
149+
150+
self.select_patcher = patch("select.select", side_effect=mock_select)
151+
self.select_patcher.start()
152+
153+
def teardown_method(self):
154+
"""Clean up test fixtures."""
155+
self.socket_patcher.stop()
156+
self.select_patcher.stop()
157+
158+
@skip_if_hiredis_parser()
159+
@pytest.mark.parametrize("protocol_version", [2, 3])
160+
def test_external_auth_provider_error(self, protocol_version):
161+
client = Redis(
162+
protocol=protocol_version,
163+
)
164+
client.set("hello", "world")
165+
166+
with pytest.raises(ExternalAuthProviderError):
167+
client.get("ldap_error")

0 commit comments

Comments
 (0)