From 5139688d1e1a4ddb36125d63ef0ce14a8c3d3961 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Mon, 18 Aug 2025 10:57:43 +0300 Subject: [PATCH 1/4] Hitless upgrade: Support initial implementation for synchronous Redis client - no handshake, no failing over notifications support. (#3713) --- redis/_parsers/base.py | 98 +- redis/_parsers/hiredis.py | 26 +- redis/_parsers/resp3.py | 16 +- redis/client.py | 81 +- redis/connection.py | 743 ++++++++- redis/maintenance_events.py | 496 ++++++ tests/test_connection_pool.py | 10 +- tests/test_maintenance_events.py | 543 +++++++ tests/test_maintenance_events_handling.py | 1770 +++++++++++++++++++++ 9 files changed, 3691 insertions(+), 92 deletions(-) create mode 100644 redis/maintenance_events.py create mode 100644 tests/test_maintenance_events.py create mode 100644 tests/test_maintenance_events_handling.py diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 69d7b585dd..dd2d8b9de0 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -1,7 +1,13 @@ import sys from abc import ABC from asyncio import IncompleteReadError, StreamReader, TimeoutError -from typing import Callable, List, Optional, Protocol, Union +from typing import Awaitable, Callable, List, Optional, Protocol, Union + +from redis.maintenance_events import ( + NodeMigratedEvent, + NodeMigratingEvent, + NodeMovingEvent, +) if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout @@ -158,7 +164,19 @@ async def read_response( raise NotImplementedError() -_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"] +_INVALIDATION_MESSAGE = (b"invalidate", "invalidate") +_MOVING_MESSAGE = (b"MOVING", "MOVING") +_MIGRATING_MESSAGE = (b"MIGRATING", "MIGRATING") +_MIGRATED_MESSAGE = (b"MIGRATED", "MIGRATED") +_FAILING_OVER_MESSAGE = (b"FAILING_OVER", "FAILING_OVER") +_FAILED_OVER_MESSAGE = (b"FAILED_OVER", "FAILED_OVER") + +_MAINTENANCE_MESSAGES = ( + *_MIGRATING_MESSAGE, + *_MIGRATED_MESSAGE, + *_FAILING_OVER_MESSAGE, + *_FAILED_OVER_MESSAGE, +) class PushNotificationsParser(Protocol): @@ -166,16 +184,46 @@ class PushNotificationsParser(Protocol): pubsub_push_handler_func: Callable invalidation_push_handler_func: Optional[Callable] = None + node_moving_push_handler_func: Optional[Callable] = None + maintenance_push_handler_func: Optional[Callable] = None def handle_pubsub_push_response(self, response): """Handle pubsub push responses""" raise NotImplementedError() def handle_push_response(self, response, **kwargs): - if response[0] not in _INVALIDATION_MESSAGE: + msg_type = response[0] + if msg_type not in ( + *_INVALIDATION_MESSAGE, + *_MAINTENANCE_MESSAGES, + *_MOVING_MESSAGE, + ): return self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: + if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return self.invalidation_push_handler_func(response) + if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + # TODO: PARSE latest format when available + host, port = response[2].decode().split(":") + ttl = response[1] + id = 1 # Hardcoded value until the notification starts including the id + notification = NodeMovingEvent(id, host, port, ttl) + return self.node_moving_push_handler_func(notification) + if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: + if msg_type in _MIGRATING_MESSAGE: + # TODO: PARSE latest format when available + ttl = response[1] + id = 2 # Hardcoded value until the notification starts including the id + 
notification = NodeMigratingEvent(id, ttl)
+            elif msg_type in _MIGRATED_MESSAGE:
+                # TODO: PARSE latest format when available
+                id = 3  # Hardcoded value until the notification starts including the id
+                notification = NodeMigratedEvent(id)
+            else:
+                notification = None
+            if notification is not None:
+                return self.maintenance_push_handler_func(notification)
+            else:
+                return None
 
     def set_pubsub_push_handler(self, pubsub_push_handler_func):
         self.pubsub_push_handler_func = pubsub_push_handler_func
@@ -183,12 +231,20 @@ def set_pubsub_push_handler(self, pubsub_push_handler_func):
     def set_invalidation_push_handler(self, invalidation_push_handler_func):
         self.invalidation_push_handler_func = invalidation_push_handler_func
 
+    def set_node_moving_push_handler(self, node_moving_push_handler_func):
+        self.node_moving_push_handler_func = node_moving_push_handler_func
+
+    def set_maintenance_push_handler(self, maintenance_push_handler_func):
+        self.maintenance_push_handler_func = maintenance_push_handler_func
+
 
 class AsyncPushNotificationsParser(Protocol):
     """Protocol defining async RESP3-specific parsing functionality"""
 
     pubsub_push_handler_func: Callable
     invalidation_push_handler_func: Optional[Callable] = None
+    node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
+    maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
 
     async def handle_pubsub_push_response(self, response):
         """Handle pubsub push responses asynchronously"""
@@ -196,10 +252,40 @@ async def handle_pubsub_push_response(self, response):
 
     async def handle_push_response(self, response, **kwargs):
         """Handle push responses asynchronously"""
-        if response[0] not in _INVALIDATION_MESSAGE:
+        msg_type = response[0]
+        if msg_type not in (
+            *_INVALIDATION_MESSAGE,
+            *_MAINTENANCE_MESSAGES,
+            *_MOVING_MESSAGE,
+        ):
             return await self.pubsub_push_handler_func(response)
-        if self.invalidation_push_handler_func:
+        if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func:
             return await self.invalidation_push_handler_func(response)
+        if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
+            # push notification from enterprise cluster for node moving
+            # TODO: PARSE latest format when available
+            host, port = response[2].split(":")
+            ttl = response[1]
+            id = 1  # Hardcoded value for async parser
+            notification = NodeMovingEvent(id, host, port, ttl)
+            return await self.node_moving_push_handler_func(notification)
+        if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
+            if msg_type in _MIGRATING_MESSAGE:
+                # TODO: PARSE latest format when available
+                ttl = response[1]
+                id = 2  # Hardcoded value for async parser
+                notification = NodeMigratingEvent(id, ttl)
+            elif msg_type in _MIGRATED_MESSAGE:
+                # TODO: PARSE latest format when available
+                id = 3  # Hardcoded value for async parser
+                notification = NodeMigratedEvent(id)
+            else:
+                # FAILING_OVER / FAILED_OVER are recognized but not handled yet;
+                # mirror the sync parser and produce no notification for them
+                notification = None
+            if notification is not None:
+                return await self.maintenance_push_handler_func(notification)
+            return None
 
     def set_pubsub_push_handler(self, pubsub_push_handler_func):
         """Set the pubsub push handler function"""
@@ -209,6 +295,12 @@ def set_invalidation_push_handler(self, invalidation_push_handler_func):
         """Set the invalidation push handler function"""
         self.invalidation_push_handler_func = invalidation_push_handler_func
 
+    def set_node_moving_push_handler(self, node_moving_push_handler_func):
+        self.node_moving_push_handler_func = node_moving_push_handler_func
+
+    def set_maintenance_push_handler(self, maintenance_push_handler_func):
+        self.maintenance_push_handler_func = 
maintenance_push_handler_func + class _AsyncRESPBase(AsyncBaseParser): """Base class for async resp parsing""" diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py index 521a58b26c..d82fe99cd9 100644 --- a/redis/_parsers/hiredis.py +++ b/redis/_parsers/hiredis.py @@ -47,6 +47,8 @@ def __init__(self, socket_read_size): self.socket_read_size = socket_read_size self._buffer = bytearray(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.node_moving_push_handler_func = None + self.maintenance_push_handler_func = None self.invalidation_push_handler_func = None self._hiredis_PushNotificationType = None @@ -141,12 +143,15 @@ def read_response(self, disable_decoding=False, push_request=False): response, self._hiredis_PushNotificationType ): response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + + # if this is a push request return the push response + if push_request: return response + + return self.read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) return response if disable_decoding: @@ -169,12 +174,13 @@ def read_response(self, disable_decoding=False, push_request=False): response, self._hiredis_PushNotificationType ): response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + if push_request: return response + return self.read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) + elif ( isinstance(response, list) and response diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 42c6652e31..72957b464c 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -18,6 +18,8 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser): def __init__(self, socket_read_size): super().__init__(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.node_moving_push_handler_func = None + self.maintenance_push_handler_func = None self.invalidation_push_handler_func = None def handle_pubsub_push_response(self, response): @@ -117,17 +119,21 @@ def _read_response(self, disable_decoding=False, push_request=False): for _ in range(int(response)) ] response = self.handle_push_response(response) - if not push_request: - return self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + + # if this is a push request return the push response + if push_request: return response + + return self._read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) + return response diff --git a/redis/client.py b/redis/client.py index 0e05b6f542..26837b673b 100755 --- a/redis/client.py +++ b/redis/client.py @@ -56,6 +56,10 @@ WatchError, ) from redis.lock import Lock +from redis.maintenance_events import ( + MaintenanceEventPoolHandler, + MaintenanceEventsConfig, +) from redis.retry import Retry from redis.utils import ( _set_info_logger, @@ -244,6 +248,7 @@ def __init__( cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, event_dispatcher: Optional[EventDispatcher] = None, + maintenance_events_config: Optional[MaintenanceEventsConfig] = None, ) -> None: """ Initialize a new 
Redis client. @@ -368,6 +373,23 @@ def __init__( ]: raise RedisError("Client caching is only supported with RESP version 3") + if maintenance_events_config and self.connection_pool.get_protocol() not in [ + 3, + "3", + ]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + if maintenance_events_config and maintenance_events_config.enabled: + self.maintenance_events_pool_handler = MaintenanceEventPoolHandler( + self.connection_pool, maintenance_events_config + ) + self.connection_pool.set_maintenance_events_pool_handler( + self.maintenance_events_pool_handler + ) + else: + self.maintenance_events_pool_handler = None + self.single_connection_lock = threading.RLock() self.connection = None self._single_connection_client = single_connection_client @@ -565,8 +587,15 @@ def monitor(self): return Monitor(self.connection_pool) def client(self): + maintenance_events_config = ( + None + if self.maintenance_events_pool_handler is None + else self.maintenance_events_pool_handler.config + ) return self.__class__( - connection_pool=self.connection_pool, single_connection_client=True + connection_pool=self.connection_pool, + single_connection_client=True, + maintenance_events_config=maintenance_events_config, ) def __enter__(self): @@ -635,7 +664,11 @@ def _execute_command(self, *args, **options): ), lambda _: self._close_connection(conn), ) + finally: + if conn and conn.should_reconnect(): + self._close_connection(conn) + conn.connect() if self._single_connection_client: self.single_connection_lock.release() if not self.connection: @@ -686,11 +719,7 @@ def __init__(self, connection_pool): self.connection = self.connection_pool.get_connection() def __enter__(self): - self.connection.send_command("MONITOR") - # check that monitor returns 'OK', but don't return it to user - response = self.connection.read_response() - if not bool_ok(response): - raise RedisError(f"MONITOR failed: {response}") + self._start_monitor() return self def __exit__(self, *args): @@ -700,8 +729,13 @@ def __exit__(self, *args): def next_command(self): """Parse the response from a monitor command""" response = self.connection.read_response() + + if response is None: + return None + if isinstance(response, bytes): response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) m = self.monitor_re.match(command_data) db_id, client_info, command = m.groups() @@ -737,6 +771,14 @@ def listen(self): while True: yield self.next_command() + def _start_monitor(self): + self.connection.send_command("MONITOR") + # check that monitor returns 'OK', but don't return it to user + response = self.connection.read_response() + + if not bool_ok(response): + raise RedisError(f"MONITOR failed: {response}") + class PubSub: """ @@ -881,7 +923,7 @@ def clean_health_check_responses(self) -> None: """ ttl = 10 conn = self.connection - while self.health_check_response_counter > 0 and ttl > 0: + while conn and self.health_check_response_counter > 0 and ttl > 0: if self._execute(conn, conn.can_read, timeout=conn.socket_timeout): response = self._execute(conn, conn.read_response) if self.is_health_check_response(response): @@ -911,11 +953,17 @@ def _execute(self, conn, command, *args, **kwargs): called by the # connection to resubscribe us to any channels and patterns we were previously listening to """ - return conn.retry.call_with_retry( + + if conn.should_reconnect(): + self._reconnect(conn) + + response = conn.retry.call_with_retry( lambda: command(*args, 
**kwargs), lambda _: self._reconnect(conn), ) + return response + def parse_response(self, block=True, timeout=0): """Parse the response from a publish/subscribe command""" conn = self.connection @@ -1125,6 +1173,7 @@ def get_message( return None response = self.parse_response(block=(timeout is None), timeout=timeout) + if response: return self.handle_message(response, ignore_subscribe_messages) return None @@ -1148,6 +1197,7 @@ def handle_message(self, response, ignore_subscribe_messages=False): return None if isinstance(response, bytes): response = [b"pong", response] if response != b"PONG" else [b"pong", b""] + message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { @@ -1351,6 +1401,7 @@ def reset(self) -> None: # clean up the other instance attributes self.watching = False self.explicit_transaction = False + # we can safely return the connection to the pool here since we're # sure we're no longer WATCHing anything if self.connection: @@ -1510,6 +1561,7 @@ def _execute_transaction( if command_name in self.response_callbacks: r = self.response_callbacks[command_name](r, **options) data.append(r) + return data def _execute_pipeline(self, connection, commands, raise_on_error): @@ -1517,16 +1569,17 @@ def _execute_pipeline(self, connection, commands, raise_on_error): all_cmds = connection.pack_commands([args for args, _ in commands]) connection.send_packed_command(all_cmds) - response = [] + responses = [] for args, options in commands: try: - response.append(self.parse_response(connection, args[0], **options)) + responses.append(self.parse_response(connection, args[0], **options)) except ResponseError as e: - response.append(e) + responses.append(e) if raise_on_error: - self.raise_first_error(commands, response) - return response + self.raise_first_error(commands, responses) + + return responses def raise_first_error(self, commands, response): for i, r in enumerate(response): @@ -1611,6 +1664,8 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: lambda error: self._disconnect_raise_on_watching(conn, error), ) finally: + # in reset() the connection is disconnected before returned to the pool if + # it is marked for reconnect. 
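For orientation, the client-level wiring shown above can be exercised as in the following minimal sketch. This is illustrative only, not part of the patch: the host, port, and option values are placeholders, and RESP3 (protocol=3) is mandatory, as enforced in __init__.

import redis
from redis.maintenance_events import MaintenanceEventsConfig

config = MaintenanceEventsConfig(
    enabled=True,              # handling is opt-in; the default is False
    proactive_reconnect=True,  # reconnect connections whose node is moving
    relax_timeout=20,          # relaxed socket timeout during maintenance; -1 disables
)
client = redis.Redis(
    host="localhost",  # placeholder
    port=6379,         # placeholder
    protocol=3,        # maintenance events require RESP3
    maintenance_events_config=config,
)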
self.reset() def discard(self): diff --git a/redis/connection.py b/redis/connection.py index 47cb589569..1389f77476 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,7 +8,18 @@ from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Type, + TypeVar, + Union, +) from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( @@ -19,6 +30,7 @@ CacheInterface, CacheKey, ) +from redis.typing import Number from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .auth.token import TokenInterface @@ -36,6 +48,12 @@ ResponseError, TimeoutError, ) +from .maintenance_events import ( + MaintenanceEventConnectionHandler, + MaintenanceEventPoolHandler, + MaintenanceEventsConfig, + MaintenanceState, +) from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, @@ -159,6 +177,10 @@ def deregister_connect_callback(self, callback): def set_parser(self, parser_class): pass + @abstractmethod + def set_maintenance_event_pool_handler(self, maintenance_event_pool_handler): + pass + @abstractmethod def get_protocol(self): pass @@ -222,6 +244,73 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): pass + @property + @abstractmethod + def maintenance_state(self) -> MaintenanceState: + """ + Returns the current maintenance state of the connection. + """ + pass + + @maintenance_state.setter + @abstractmethod + def maintenance_state(self, state: "MaintenanceState"): + """ + Sets the current maintenance state of the connection. + """ + pass + + @abstractmethod + def getpeername(self): + """ + Returns the peer name of the connection. + """ + pass + + @abstractmethod + def mark_for_reconnect(self): + """ + Mark the connection to be reconnected on the next command. + This is useful when a connection is moved to a different node. + """ + pass + + @abstractmethod + def should_reconnect(self): + """ + Returns True if the connection should be reconnected. + """ + pass + + @abstractmethod + def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + """ + Update the timeout for the current socket. + """ + pass + + @abstractmethod + def set_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Updates temporary host address and timeout settings for the connection. + """ + pass + + @abstractmethod + def reset_tmp_settings( + self, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + """ + Resets temporary host address and timeout settings for the connection. 
+ """ + pass + class AbstractConnection(ConnectionInterface): "Manages communication to and from a Redis server" @@ -233,7 +322,7 @@ def __init__( socket_timeout: Optional[float] = None, socket_connect_timeout: Optional[float] = None, retry_on_timeout: bool = False, - retry_on_error=SENTINEL, + retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL, encoding: str = "utf-8", encoding_errors: str = "strict", decode_responses: bool = False, @@ -250,6 +339,12 @@ def __init__( protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, event_dispatcher: Optional[EventDispatcher] = None, + maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None, + maintenance_events_config: Optional[MaintenanceEventsConfig] = None, + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, + orig_host_address: Optional[str] = None, + orig_socket_timeout: Optional[float] = None, + orig_socket_connect_timeout: Optional[float] = None, ): """ Initialize a new Connection. @@ -283,19 +378,22 @@ def __init__( self.socket_connect_timeout = socket_connect_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: - retry_on_error = [] + retry_on_errors_list = [] + else: + retry_on_errors_list = list(retry_on_error) if retry_on_timeout: # Add TimeoutError to the errors list to retry on - retry_on_error.append(TimeoutError) - self.retry_on_error = retry_on_error - if retry or retry_on_error: + retry_on_errors_list.append(TimeoutError) + self.retry_on_error = retry_on_errors_list + if retry or self.retry_on_error: if retry is None: self.retry = Retry(NoBackoff(), 1) else: # deep-copy the Retry object as it is mutable self.retry = copy.deepcopy(retry) - # Update the retry's supported errors with the specified errors - self.retry.update_supported_errors(retry_on_error) + if self.retry_on_error: + # Update the retry's supported errors with the specified errors + self.retry.update_supported_errors(self.retry_on_error) else: self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval @@ -305,7 +403,6 @@ def __init__( self.handshake_metadata = None self._sock = None self._socket_read_size = socket_read_size - self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 self._re_auth_token: Optional[TokenInterface] = None @@ -320,6 +417,39 @@ def __init__( raise ConnectionError("protocol must be either 2 or 3") # p = DEFAULT_RESP_VERSION self.protocol = p + if self.protocol == 3 and parser_class == DefaultParser: + parser_class = _RESP3Parser + self.set_parser(parser_class) + + if maintenance_events_config and maintenance_events_config.enabled: + if maintenance_events_pool_handler: + maintenance_events_pool_handler.set_connection(self) + self._parser.set_node_moving_push_handler( + maintenance_events_pool_handler.handle_event + ) + self._maintenance_event_connection_handler = ( + MaintenanceEventConnectionHandler(self, maintenance_events_config) + ) + self._parser.set_maintenance_push_handler( + self._maintenance_event_connection_handler.handle_event + ) + + self.orig_host_address = ( + orig_host_address if orig_host_address else self.host + ) + self.orig_socket_timeout = ( + orig_socket_timeout if orig_socket_timeout else self.socket_timeout + ) + self.orig_socket_connect_timeout = ( + orig_socket_connect_timeout + if orig_socket_connect_timeout + else self.socket_connect_timeout + ) + else: + self._maintenance_event_connection_handler = None + self._should_reconnect = False + 
self.maintenance_state = maintenance_state + self._command_packer = self._construct_command_packer(command_packer) def __repr__(self): @@ -375,6 +505,29 @@ def set_parser(self, parser_class): """ self._parser = parser_class(socket_read_size=self._socket_read_size) + def set_maintenance_event_pool_handler( + self, maintenance_event_pool_handler: MaintenanceEventPoolHandler + ): + maintenance_event_pool_handler.set_connection(self) + self._parser.set_node_moving_push_handler( + maintenance_event_pool_handler.handle_event + ) + + # Update maintenance event connection handler if it doesn't exist + if not self._maintenance_event_connection_handler: + self._maintenance_event_connection_handler = ( + MaintenanceEventConnectionHandler( + self, maintenance_event_pool_handler.config + ) + ) + self._parser.set_maintenance_push_handler( + self._maintenance_event_connection_handler.handle_event + ) + else: + self._maintenance_event_connection_handler.config = ( + maintenance_event_pool_handler.config + ) + def connect(self): "Connects to the Redis server if not already connected" self.connect_check_health(check_health=True) @@ -549,6 +702,8 @@ def disconnect(self, *args): conn_sock = self._sock self._sock = None + # reset the reconnect flag + self._should_reconnect = False if conn_sock is None: return @@ -626,6 +781,7 @@ def can_read(self, timeout=0): try: return self._parser.can_read(timeout) + except OSError as e: self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") @@ -732,6 +888,60 @@ def re_auth(self): self.read_response() self._re_auth_token = None + @property + def maintenance_state(self) -> MaintenanceState: + return self._maintenance_state + + @maintenance_state.setter + def maintenance_state(self, state: "MaintenanceState"): + self._maintenance_state = state + + def getpeername(self): + if not self._sock: + return None + return self._sock.getpeername()[0] + + def mark_for_reconnect(self): + self._should_reconnect = True + + def should_reconnect(self): + return self._should_reconnect + + def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + if self._sock: + timeout = relax_timeout if relax_timeout != -1 else self.socket_timeout + self._sock.settimeout(timeout) + self.update_parser_buffer_timeout(timeout) + + def update_parser_buffer_timeout(self, timeout: Optional[float] = None): + if self._parser and self._parser._buffer: + self._parser._buffer.socket_timeout = timeout + + def set_tmp_settings( + self, + tmp_host_address: Optional[Union[str, object]] = SENTINEL, + tmp_relax_timeout: Optional[float] = None, + ): + """ + The value of SENTINEL is used to indicate that the property should not be updated. 
+ """ + if tmp_host_address is not SENTINEL: + self.host = tmp_host_address + if tmp_relax_timeout != -1: + self.socket_timeout = tmp_relax_timeout + self.socket_connect_timeout = tmp_relax_timeout + + def reset_tmp_settings( + self, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + if reset_host_address: + self.host = self.orig_host_address + if reset_relax_timeout: + self.socket_timeout = self.orig_socket_timeout + self.socket_connect_timeout = self.orig_socket_connect_timeout + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -764,6 +974,7 @@ def _connect(self): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None + for res in socket.getaddrinfo( self.host, self.port, self.socket_type, socket.SOCK_STREAM ): @@ -1415,6 +1626,32 @@ def __init__( connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) + if connection_kwargs.get( + "maintenance_events_pool_handler" + ) or connection_kwargs.get("maintenance_events_config"): + if connection_kwargs.get("protocol") not in [3, "3"]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + config = connection_kwargs.get("maintenance_events_config", None) or ( + connection_kwargs.get("maintenance_events_pool_handler").config + if connection_kwargs.get("maintenance_events_pool_handler") + else None + ) + + if config and config.enabled: + connection_kwargs.update( + { + "orig_host_address": connection_kwargs.get("host"), + "orig_socket_timeout": connection_kwargs.get( + "socket_timeout", None + ), + "orig_socket_connect_timeout": connection_kwargs.get( + "socket_connect_timeout", None + ), + } + ) + self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) if self._event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -1449,6 +1686,43 @@ def get_protocol(self): """ return self.connection_kwargs.get("protocol", None) + def maintenance_events_pool_handler_enabled(self): + """ + Returns: + True if the maintenance events pool handler is enabled, False otherwise. 
+ """ + maintenance_events_config = self.connection_kwargs.get( + "maintenance_events_config", None + ) + + return maintenance_events_config and maintenance_events_config.enabled + + def set_maintenance_events_pool_handler( + self, maintenance_events_pool_handler: MaintenanceEventPoolHandler + ): + self.connection_kwargs.update( + { + "maintenance_events_pool_handler": maintenance_events_pool_handler, + "maintenance_events_config": maintenance_events_pool_handler.config, + } + ) + + self._update_maintenance_events_configs_for_connections( + maintenance_events_pool_handler + ) + + def _update_maintenance_events_configs_for_connections( + self, maintenance_events_pool_handler + ): + """Update the maintenance events config for all connections in the pool.""" + with self._lock: + for conn in self._available_connections: + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) + conn.maintenance_events_config = maintenance_events_pool_handler.config + for conn in self._in_use_connections: + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) + conn.maintenance_events_config = maintenance_events_pool_handler.config + def reset(self) -> None: self._created_connections = 0 self._available_connections = [] @@ -1536,7 +1810,11 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. try: - if connection.can_read() and self.cache is None: + if ( + connection.can_read() + and self.cache is None + and not self.maintenance_events_pool_handler_enabled() + ): raise ConnectionError("Connection has data") except (ConnectionError, TimeoutError, OSError): connection.disconnect() @@ -1548,7 +1826,6 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": # leak it self.release(connection) raise - return connection def get_encoder(self) -> Encoder: @@ -1566,12 +1843,13 @@ def make_connection(self) -> "ConnectionInterface": raise MaxConnectionsError("Too many connections") self._created_connections += 1 + kwargs = dict(self.connection_kwargs) + if self.cache is not None: return CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock + self.connection_class(**kwargs), self.cache, self._lock ) - - return self.connection_class(**self.connection_kwargs) + return self.connection_class(**kwargs) def release(self, connection: "Connection") -> None: "Releases the connection back to the pool" @@ -1585,6 +1863,8 @@ def release(self, connection: "Connection") -> None: return if self.owns_connection(connection): + if connection.should_reconnect(): + connection.disconnect() self._available_connections.append(connection) self._event_dispatcher.dispatch( AfterConnectionReleasedEvent(connection) @@ -1646,6 +1926,175 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) + def should_update_connection( + self, + conn: "Connection", + address_type_to_match: Literal["connected", "configured"] = "connected", + matching_address: Optional[str] = None, + ) -> bool: + """ + Check if the connection should be updated based on the matching address. 
+ """ + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + return False + else: + if matching_address and conn.host != matching_address: + return False + return True + + def update_connection_settings( + self, + conn: "Connection", + state: Optional["MaintenanceState"] = None, + relax_timeout: Optional[float] = None, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + """ + Update the settings for a single connection. + """ + if state: + conn.maintenance_state = state + + if reset_relax_timeout or reset_host_address: + conn.reset_tmp_settings( + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + + conn.update_current_socket_timeout(relax_timeout) + + def update_connections_settings( + self, + state: Optional["MaintenanceState"] = None, + relax_timeout: Optional[float] = None, + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + include_free_connections: bool = True, + ): + """ + Update the settings for all matching connections in the pool. + + This method does not create new connections. + This method does not affect the connection kwargs. + + :param state: The maintenance state to set for the connection. + :param relax_timeout: The relax timeout to set for the connection. + :param matching_address: The address to match for the connection. + :param address_type_to_match: The type of address to match. + :param reset_host_address: Whether to reset the host address to the original address. + :param reset_relax_timeout: Whether to reset the relax timeout to the original timeout. + :param include_free_connections: Whether to include free/available connections. + """ + with self._lock: + for conn in self._in_use_connections: + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + self.update_connection_settings( + conn, + state=state, + relax_timeout=relax_timeout, + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + + if include_free_connections: + for conn in self._available_connections: + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + self.update_connection_settings( + conn, + state=state, + relax_timeout=relax_timeout, + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + + def update_connection_kwargs( + self, + **kwargs, + ): + """ + Update the connection kwargs for all future connections. + + This method updates the connection kwargs for all future connections created by the pool. + Existing connections are not affected. + """ + self.connection_kwargs.update(kwargs) + + def update_active_connections_for_reconnect( + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, + ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + :param moving_address_src: The address of the node that is being moved. 
+ """ + with self._lock: + for conn in self._in_use_connections: + if self.should_update_connection(conn, "connected", moving_address_src): + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def disconnect_and_reconfigure_free_connections( + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, + ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + :param moving_address_src: The address of the node that is being moved. + """ + with self._lock: + for conn in self._available_connections: + if self.should_update_connection(conn, "connected", moving_address_src): + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def _update_connection_for_reconnect( + self, + connection: "Connection", + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + ): + connection.mark_for_reconnect() + connection.set_tmp_settings( + tmp_host_address=tmp_host_address, tmp_relax_timeout=tmp_relax_timeout + ) + + def _disconnect_and_update_connection_for_reconnect( + self, + connection: "Connection", + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + ): + connection.disconnect() + connection.set_tmp_settings( + tmp_host_address=tmp_host_address, tmp_relax_timeout=tmp_relax_timeout + ) + async def _mock(self, error: RedisError): """ Dummy functions, needs to be passed as error callback to retry object. @@ -1699,6 +2148,8 @@ def __init__( ): self.queue_class = queue_class self.timeout = timeout + self._in_maintenance = False + self._locked = False super().__init__( connection_class=connection_class, max_connections=max_connections, @@ -1707,16 +2158,27 @@ def __init__( def reset(self): # Create and fill up a thread safe queue with ``None`` values. - self.pool = self.queue_class(self.max_connections) - while True: - try: - self.pool.put_nowait(None) - except Full: - break + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True + self.pool = self.queue_class(self.max_connections) + while True: + try: + self.pool.put_nowait(None) + except Full: + break - # Keep a list of actual connection instances so that we can - # disconnect them later. - self._connections = [] + # Keep a list of actual connection instances so that we can + # disconnect them later. + self._connections = [] + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False # this must be the last operation in this method. while reset() is # called when holding _fork_lock, other threads in this process @@ -1731,14 +2193,28 @@ def reset(self): def make_connection(self): "Make a fresh connection." 
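To make the pool-level flow concrete, the following sketch traces the path of a MOVING notification by hand. It is illustrative, not patch code: `pool` stands for an existing ConnectionPool, and the event values are placeholders (in production the RESP3 parser constructs the event from the push message).

from redis.maintenance_events import (
    MaintenanceEventPoolHandler,
    MaintenanceEventsConfig,
    NodeMovingEvent,
)

handler = MaintenanceEventPoolHandler(pool, MaintenanceEventsConfig(enabled=True))
event = NodeMovingEvent(id=1, new_node_host="10.0.0.5", new_node_port=6379, ttl=15)
handler.handle_event(event)
# handle_event() relaxes timeouts on matching connections, marks in-use
# connections for reconnect, reconfigures free ones, updates the pool's
# connection kwargs, and schedules handle_node_moved_event() to revert
# everything once the event's ttl elapses.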
- if self.cache is not None: - connection = CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock - ) - else: - connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) - return connection + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True + + if self.cache is not None: + connection = CacheProxyConnection( + self.connection_class(**self.connection_kwargs), + self.cache, + self._lock, + ) + else: + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False @deprecated_args( args_to_warn=["*"], @@ -1764,16 +2240,27 @@ def get_connection(self, command_name=None, *keys, **options): # self.timeout then raise a ``ConnectionError``. connection = None try: - connection = self.pool.get(block=True, timeout=self.timeout) - except Empty: - # Note that this is not caught by the redis client and will be - # raised unless handled by application code. If you want never to - raise ConnectionError("No connection available.") - - # If the ``connection`` is actually ``None`` then that's a cue to make - # a new connection to add to the pool. - if connection is None: - connection = self.make_connection() + if self._in_maintenance: + self._lock.acquire() + self._locked = True + try: + connection = self.pool.get(block=True, timeout=self.timeout) + except Empty: + # Note that this is not caught by the redis client and will be + # raised unless handled by application code. If you want never to + raise ConnectionError("No connection available.") + + # If the ``connection`` is actually ``None`` then that's a cue to make + # a new connection to add to the pool. + if connection is None: + connection = self.make_connection() + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False try: # ensure this connection is connected to Redis @@ -1801,25 +2288,167 @@ def release(self, connection): "Releases the connection back to the pool." # Make sure we haven't changed process. self._checkpid() - if not self.owns_connection(connection): - # pool doesn't own this connection. do not add it back - # to the pool. instead add a None value which is a placeholder - # that will cause the pool to recreate the connection if - # its needed. - connection.disconnect() - self.pool.put_nowait(None) - return - # Put the connection back into the pool. try: - self.pool.put_nowait(connection) - except Full: - # perhaps the pool has been reset() after a fork? regardless, - # we don't want this connection - pass + if self._in_maintenance: + self._lock.acquire() + self._locked = True + if not self.owns_connection(connection): + # pool doesn't own this connection. do not add it back + # to the pool. instead add a None value which is a placeholder + # that will cause the pool to recreate the connection if + # its needed. + connection.disconnect() + self.pool.put_nowait(None) + return + if connection.should_reconnect(): + connection.disconnect() + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except Full: + # perhaps the pool has been reset() after a fork? regardless, + # we don't want this connection + pass + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False def disconnect(self): "Disconnects all connections in the pool." 
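The release() logic above is one half of the reconnect handshake; the sketch below shows the life-cycle of the flag it checks. Illustrative only: `pool` is an existing pool and the command is a placeholder.

conn = pool.get_connection()
try:
    # While a command runs, a MOVING push may arrive; the pool handler then
    # calls conn.mark_for_reconnect() and conn.set_tmp_settings(...) so the
    # connection targets the temporary (new) host address.
    conn.send_command("PING")
    conn.read_response()
finally:
    # release() sees conn.should_reconnect(), disconnects the socket, and
    # returns the connection to the pool; the next use reconnects it.
    pool.release(conn)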
self._checkpid() - for connection in self._connections: - connection.disconnect() + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True + for connection in self._connections: + connection.disconnect() + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False + + def update_connections_settings( + self, + state: Optional["MaintenanceState"] = None, + relax_timeout: Optional[float] = None, + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + include_free_connections: bool = True, + ): + """ + Override base class method to work with BlockingConnectionPool's structure. + """ + with self._lock: + if include_free_connections: + for conn in tuple(self._connections): + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + self.update_connection_settings( + conn, + state=state, + relax_timeout=relax_timeout, + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + else: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + self.update_connection_settings( + conn, + state=state, + relax_timeout=relax_timeout, + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + + def update_active_connections_for_reconnect( + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, + ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + :param moving_address_src: The address of the node that is being moved. + """ + with self._lock: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + if moving_address_src and conn.getpeername() != moving_address_src: + continue + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def disconnect_and_reconfigure_free_connections( + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[Number] = None, + moving_address_src: Optional[str] = None, + ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + :param moving_address_src: The address of the node that is being moved. 
+ """ + with self._lock: + existing_connections = self.pool.queue + + for conn in existing_connections: + if conn: + if moving_address_src and conn.getpeername() != moving_address_src: + continue + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def _update_maintenance_events_config_for_connections( + self, maintenance_events_config + ): + for conn in tuple(self._connections): + conn.maintenance_events_config = maintenance_events_config + + def _update_maintenance_events_configs_for_connections( + self, maintenance_events_pool_handler + ): + """Update the maintenance events config for all connections in the pool.""" + with self._lock: + for conn in tuple(self._connections): + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) + conn.maintenance_events_config = maintenance_events_pool_handler.config + + def set_in_maintenance(self, in_maintenance: bool): + """ + Sets a flag that this Blocking ConnectionPool is in maintenance mode. + + This is used to prevent new connections from being created while we are in maintenance mode. + The pool will be in maintenance mode only when we are processing a MOVING event. + """ + self._in_maintenance = in_maintenance diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py new file mode 100644 index 0000000000..8c6c15c74a --- /dev/null +++ b/redis/maintenance_events.py @@ -0,0 +1,496 @@ +import enum +import logging +import threading +import time +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional, Union + +from redis.typing import Number + + +class MaintenanceState(enum.Enum): + NONE = "none" + MOVING = "moving" + MIGRATING = "migrating" + + +if TYPE_CHECKING: + from redis.connection import ( + BlockingConnectionPool, + ConnectionInterface, + ConnectionPool, + ) + + +class MaintenanceEvent(ABC): + """ + Base class for maintenance events sent through push messages by Redis server. + + This class provides common functionality for all maintenance events including + unique identification and TTL (Time-To-Live) functionality. + + Attributes: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + creation_time (float): Timestamp when the notification was created/read + """ + + def __init__(self, id: int, ttl: int): + """ + Initialize a new MaintenanceEvent with unique ID and TTL functionality. + + Args: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + """ + self.id = id + self.ttl = ttl + self.creation_time = time.monotonic() + self.expire_at = self.creation_time + self.ttl + + def is_expired(self) -> bool: + """ + Check if this event has expired based on its TTL + and creation time. + + Returns: + bool: True if the event has expired, False otherwise + """ + return time.monotonic() > (self.creation_time + self.ttl) + + @abstractmethod + def __repr__(self) -> str: + """ + Return a string representation of the maintenance event. + + This method must be implemented by all concrete subclasses. + + Returns: + str: String representation of the event + """ + pass + + @abstractmethod + def __eq__(self, other) -> bool: + """ + Compare two maintenance events for equality. + + This method must be implemented by all concrete subclasses. + Events are typically considered equal if they have the same id + and are of the same type. 
+ + Args: + other: The other object to compare with + + Returns: + bool: True if the events are equal, False otherwise + """ + pass + + @abstractmethod + def __hash__(self) -> int: + """ + Return a hash value for the maintenance event. + + This method must be implemented by all concrete subclasses to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value for the event + """ + pass + + +class NodeMovingEvent(MaintenanceEvent): + """ + This event is received when a node is replaced with a new node + during cluster rebalancing or maintenance operations. + """ + + def __init__(self, id: int, new_node_host: str, new_node_port: int, ttl: int): + """ + Initialize a new NodeMovingEvent. + + Args: + id (int): Unique identifier for this event + new_node_host (str): Hostname or IP address of the new replacement node + new_node_port (int): Port number of the new replacement node + ttl (int): Time-to-live in seconds for this notification + """ + super().__init__(id, ttl) + self.new_node_host = new_node_host + self.new_node_port = new_node_port + + def __repr__(self) -> str: + expiry_time = self.expire_at + remaining = max(0, expiry_time - time.monotonic()) + + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"new_node_host='{self.new_node_host}', " + f"new_node_port={self.new_node_port}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMovingEvent events are considered equal if they have the same + id, new_node_host, and new_node_port. + """ + if not isinstance(other, NodeMovingEvent): + return False + return ( + self.id == other.id + and self.new_node_host == other.new_node_host + and self.new_node_port == other.new_node_port + ) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type, id, new_node_host, and new_node_port + """ + return hash((self.__class__, self.id, self.new_node_host, self.new_node_port)) + + +class NodeMigratingEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node is in the process of migrating slots. + + This event is received when a node starts migrating its slots to another node + during cluster rebalancing or maintenance operations. + + Args: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + """ + + def __init__(self, id: int, ttl: int): + super().__init__(id, ttl) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMigratingEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeMigratingEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. 
+ + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + +class NodeMigratedEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node has completed migrating slots. + + This event is received when a node has finished migrating all its slots + to other nodes during cluster rebalancing or maintenance operations. + + Args: + id (int): Unique identifier for this event + """ + + DEFAULT_TTL = 5 + + def __init__(self, id: int): + super().__init__(id, NodeMigratedEvent.DEFAULT_TTL) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMigratedEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeMigratedEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + +class MaintenanceEventsConfig: + """ + Configuration class for maintenance events handling behaviour. Events are received through + push notifications. + + This class defines how the Redis client should react to different push notifications + such as node moving, migrations, etc. in a Redis cluster. + + """ + + def __init__( + self, + enabled: bool = False, + proactive_reconnect: bool = True, + relax_timeout: Optional[Number] = 20, + ): + """ + Initialize a new MaintenanceEventsConfig. + + Args: + enabled (bool): Whether to enable maintenance events handling. + Defaults to False. + proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced. + Defaults to True. + relax_timeout (Number): The relax timeout to use for the connection during maintenance. + If -1 is provided - the relax timeout is disabled. Defaults to 20. + + """ + self.enabled = enabled + self.relax_timeout = relax_timeout + self.proactive_reconnect = proactive_reconnect + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"enabled={self.enabled}, " + f"proactive_reconnect={self.proactive_reconnect}, " + f"relax_timeout={self.relax_timeout}, " + f")" + ) + + def is_relax_timeouts_enabled(self) -> bool: + """ + Check if the relax_timeout is enabled. The '-1' value is used to disable the relax_timeout. + If relax_timeout is set to None, it will make the operation blocking + and waiting until any response is received. + + Returns: + True if the relax_timeout is enabled, False otherwise. 
+ """ + return self.relax_timeout != -1 + + +class MaintenanceEventPoolHandler: + def __init__( + self, + pool: Union["ConnectionPool", "BlockingConnectionPool"], + config: MaintenanceEventsConfig, + ) -> None: + self.pool = pool + self.config = config + self._processed_events = set() + self._lock = threading.RLock() + self.connection = None + + def set_connection(self, connection: "ConnectionInterface"): + self.connection = connection + + def remove_expired_notifications(self): + with self._lock: + for notification in tuple(self._processed_events): + if notification.is_expired(): + self._processed_events.remove(notification) + + def handle_event(self, notification: MaintenanceEvent): + self.remove_expired_notifications() + + if isinstance(notification, NodeMovingEvent): + return self.handle_node_moving_event(notification) + else: + logging.error(f"Unhandled notification type: {notification}") + + def handle_node_moving_event(self, event: NodeMovingEvent): + if ( + not self.config.proactive_reconnect + and not self.config.is_relax_timeouts_enabled() + ): + return + with self._lock: + if event in self._processed_events: + # nothing to do in the connection pool handling + # the event has already been handled or is expired + # just return + return + + with self.pool._lock: + if ( + self.config.proactive_reconnect + or self.config.is_relax_timeouts_enabled() + ): + moving_address_src = ( + self.connection.getpeername() if self.connection else None + ) + + if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(True) + + # Update connection settings for all connections + self.pool.update_connections_settings( + state=MaintenanceState.MOVING, + relax_timeout=self.config.relax_timeout, + matching_address=moving_address_src, + address_type_to_match="connected", + include_free_connections=True, + ) + + if self.config.proactive_reconnect: + # take care for the active connections in the pool + # mark them for reconnect after they complete the current command + self.pool.update_active_connections_for_reconnect( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, + ) + # take care for the inactive connections in the pool + # delete them and create new ones + self.pool.disconnect_and_reconfigure_free_connections( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, + ) + + # Update config for new connections: + # Set state to MOVING + # update host + # if relax timeouts are enabled - update timeouts + kwargs: dict = { + "maintenance_state": MaintenanceState.MOVING, + "host": event.new_node_host, + } + if self.config.is_relax_timeouts_enabled(): + kwargs.update( + { + "socket_timeout": self.config.relax_timeout, + "socket_connect_timeout": self.config.relax_timeout, + } + ) + self.pool.update_connection_kwargs(**kwargs) + + if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(False) + + threading.Timer( + event.ttl, self.handle_node_moved_event, args=(event,) + ).start() + + self._processed_events.add(event) + + def handle_node_moved_event(self, event: NodeMovingEvent): + with self._lock: + # if the current host in kwargs is not matching the event + # it means there has been a new moving event after this one + # and we don't need to revert the kwargs + if self.pool.connection_kwargs.get("host") == event.new_node_host: + orig_host = self.pool.connection_kwargs.get("orig_host_address") + 
orig_socket_timeout = self.pool.connection_kwargs.get( + "orig_socket_timeout" + ) + orig_connect_timeout = self.pool.connection_kwargs.get( + "orig_socket_connect_timeout" + ) + kwargs: dict = { + "maintenance_state": MaintenanceState.NONE, + "host": orig_host, + "socket_timeout": orig_socket_timeout, + "socket_connect_timeout": orig_connect_timeout, + } + self.pool.update_connection_kwargs(**kwargs) + + with self.pool._lock: + moving_address = event.new_node_host + reset_relax_timeout = self.config.is_relax_timeouts_enabled() + reset_host_address = self.config.proactive_reconnect + + self.pool.update_connections_settings( + relax_timeout=-1, + state=MaintenanceState.NONE, + matching_address=moving_address, + address_type_to_match="configured", + reset_relax_timeout=reset_relax_timeout, + reset_host_address=reset_host_address, + include_free_connections=True, + ) + + +class MaintenanceEventConnectionHandler: + def __init__( + self, connection: "ConnectionInterface", config: MaintenanceEventsConfig + ) -> None: + self.connection = connection + self.config = config + + def handle_event(self, event: MaintenanceEvent): + if isinstance(event, NodeMigratingEvent): + return self.handle_migrating_event(event) + elif isinstance(event, NodeMigratedEvent): + return self.handle_migration_completed_event(event) + else: + logging.error(f"Unhandled event type: {event}") + + def handle_migrating_event(self, notification: NodeMigratingEvent): + if ( + self.connection.maintenance_state == MaintenanceState.MOVING + or not self.config.is_relax_timeouts_enabled() + ): + return + self.connection.maintenance_state = MaintenanceState.MIGRATING + self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) + # relax the timeout on this connection's currently open socket + self.connection.update_current_socket_timeout(self.config.relax_timeout) + + def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): + # Only reset timeouts if state is not MOVING and relax timeouts are enabled + if ( + self.connection.maintenance_state == MaintenanceState.MOVING + or not self.config.is_relax_timeouts_enabled() + ): + return + self.connection.reset_tmp_settings(reset_relax_timeout=True) + # Node migration completed - reset the connection + # timeouts by providing -1 as the relax timeout + self.connection.update_current_socket_timeout(-1) + self.connection.maintenance_state = MaintenanceState.NONE diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 3a4896f2a3..1eb68d3775 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -33,6 +33,9 @@ def connect(self): def can_read(self): return False + def should_reconnect(self): + return False + class TestConnectionPool: def get_pool( @@ -50,10 +53,14 @@ def get_pool( return pool def test_connection_creation(self): - connection_kwargs = {"foo": "bar", "biz": "baz"} + connection_kwargs = { + "foo": "bar", + "biz": "baz", + } pool = self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection ) + connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -149,6 +156,7 @@ def test_connection_creation(self, master_host): "host": master_host[0], "port": master_host[1], } + pool = self.get_pool(connection_kwargs=connection_kwargs) connection = pool.get_connection() assert isinstance(connection, DummyConnection) diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py new file mode 100644 index 
0000000000..c90fa5db4f --- /dev/null +++ b/tests/test_maintenance_events.py @@ -0,0 +1,543 @@ +import threading +from unittest.mock import Mock, patch, MagicMock +import pytest + +from redis.maintenance_events import ( + MaintenanceEvent, + NodeMovingEvent, + NodeMigratingEvent, + NodeMigratedEvent, + MaintenanceEventsConfig, + MaintenanceEventPoolHandler, + MaintenanceEventConnectionHandler, +) + + +class TestMaintenanceEvent: + """Test the base MaintenanceEvent class functionality through concrete subclasses.""" + + def test_abstract_class_cannot_be_instantiated(self): + """Test that MaintenanceEvent cannot be instantiated directly.""" + with patch("time.monotonic", return_value=1000): + with pytest.raises(TypeError): + MaintenanceEvent(id=1, ttl=10) # type: ignore + + def test_init_through_subclass(self): + """Test MaintenanceEvent initialization through concrete subclass.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event.id == 1 + assert event.ttl == 10 + assert event.creation_time == 1000 + assert event.expire_at == 1010 + + @pytest.mark.parametrize( + ("current_time", "expected_expired_state"), + [ + (1005, False), + (1015, True), + ], + ) + def test_is_expired(self, current_time, expected_expired_state): + """Test is_expired for both non-expired and expired events.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=current_time): + assert event.is_expired() == expected_expired_state + + def test_is_expired_exact_boundary(self): + """Test is_expired at exact expiration boundary.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1010): # Exactly at expiration + assert not event.is_expired() + + with patch("time.monotonic", return_value=1011): # 1 second past expiration + assert event.is_expired() + + +class TestNodeMovingEvent: + """Test the NodeMovingEvent class.""" + + def test_init(self): + """Test NodeMovingEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event.id == 1 + assert event.new_node_host == "localhost" + assert event.new_node_port == 6379 + assert event.ttl == 10 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeMovingEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1005): # 5 seconds later + repr_str = repr(event) + assert "NodeMovingEvent" in repr_str + assert "id=1" in repr_str + assert "new_node_host='localhost'" in repr_str + assert "new_node_port=6379" in repr_str + assert "ttl=10" in repr_str + assert "remaining=5.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_same_id_host_port(self): + """Test equality for events with same id, host, and port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Different TTL + assert event1 == event2 + + def test_equality_same_id_different_host(self): + 
"""Test inequality for events with same id but different host.""" + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6379, ttl=10 + ) + assert event1 != event2 + + def test_equality_same_id_different_port(self): + """Test inequality for events with same id but different port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6380, ttl=10 + ) + assert event1 != event2 + + def test_equality_different_id(self): + """Test inequality for events with different id.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event1 != event2 + + def test_equality_different_type(self): + """Test inequality for events of different types.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMigratingEvent(id=1, ttl=10) + assert event1 != event2 + + def test_hash_same_id_host_port(self): + """Test hash consistency for events with same id, host, and port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Different TTL + assert hash(event1) == hash(event2) + + def test_hash_different_host(self): + """Test hash difference for events with different host.""" + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6379, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_hash_different_port(self): + """Test hash difference for events with different port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6380, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_hash_different_id(self): + """Test hash difference for events with different id.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_set_functionality(self): + """Test that events can be used in sets correctly.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Same id, host, port - should be considered the same + event3 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6380, ttl=10 + ) # Same id but different host/port - should be different + event4 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) # Different id - should be different + + event_set = {event1, event2, event3, event4} + assert len(event_set) == 3 # event1 and event2 should be considered the same + + +class TestNodeMigratingEvent: + """Test the NodeMigratingEvent class.""" + + def test_init(self): + """Test NodeMigratingEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratingEvent(id=1, ttl=5) + assert event.id == 1 + assert event.ttl == 5 + assert 
event.creation_time == 1000 + + def test_repr(self): + """Test NodeMigratingEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratingEvent(id=1, ttl=5) + + with patch("time.monotonic", return_value=1002): # 2 seconds later + repr_str = repr(event) + assert "NodeMigratingEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=3.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeMigratingEvent.""" + event1 = NodeMigratingEvent(id=1, ttl=5) + event2 = NodeMigratingEvent(id=1, ttl=10) # Same id, different ttl + event3 = NodeMigratingEvent(id=2, ttl=5) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestNodeMigratedEvent: + """Test the NodeMigratedEvent class.""" + + def test_init(self): + """Test NodeMigratedEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratedEvent(id=1) + assert event.id == 1 + assert event.ttl == NodeMigratedEvent.DEFAULT_TTL + assert event.creation_time == 1000 + + def test_default_ttl(self): + """Test that DEFAULT_TTL is used correctly.""" + assert NodeMigratedEvent.DEFAULT_TTL == 5 + event = NodeMigratedEvent(id=1) + assert event.ttl == 5 + + def test_repr(self): + """Test NodeMigratedEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratedEvent(id=1) + + with patch("time.monotonic", return_value=1001): # 1 second later + repr_str = repr(event) + assert "NodeMigratedEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=4.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeMigratedEvent.""" + event1 = NodeMigratedEvent(id=1) + event2 = NodeMigratedEvent(id=1) # Same id + event3 = NodeMigratedEvent(id=2) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestMaintenanceEventsConfig: + """Test the MaintenanceEventsConfig class.""" + + def test_init_defaults(self): + """Test MaintenanceEventsConfig initialization with defaults.""" + config = MaintenanceEventsConfig() + assert config.enabled is False + assert config.proactive_reconnect is True + assert config.relax_timeout == 20 + + def test_init_custom_values(self): + """Test MaintenanceEventsConfig initialization with custom values.""" + config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=False, relax_timeout=30 + ) + assert config.enabled is True + assert config.proactive_reconnect is False + assert config.relax_timeout == 30 + + def test_repr(self): + """Test MaintenanceEventsConfig string representation.""" + config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=False, relax_timeout=30 + ) + repr_str = repr(config) + assert "MaintenanceEventsConfig" in repr_str + assert "enabled=True" in repr_str + assert "proactive_reconnect=False" in repr_str + assert "relax_timeout=30" in repr_str + + def test_is_relax_timeouts_enabled_true(self): + """Test is_relax_timeouts_enabled returns True for positive timeout.""" + config = MaintenanceEventsConfig(relax_timeout=20) + assert config.is_relax_timeouts_enabled() is True + + def test_is_relax_timeouts_enabled_false(self): + """Test is_relax_timeouts_enabled returns False 
for -1 timeout.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + assert config.is_relax_timeouts_enabled() is False + + def test_is_relax_timeouts_enabled_zero(self): + """Test is_relax_timeouts_enabled returns True for zero timeout.""" + config = MaintenanceEventsConfig(relax_timeout=0) + assert config.is_relax_timeouts_enabled() is True + + def test_is_relax_timeouts_enabled_none(self): + """Test is_relax_timeouts_enabled returns True for None timeout.""" + config = MaintenanceEventsConfig(relax_timeout=None) + assert config.is_relax_timeouts_enabled() is True + + def test_relax_timeout_none_is_saved_as_none(self): + """Test that None value for relax_timeout is saved as None.""" + config = MaintenanceEventsConfig(relax_timeout=None) + assert config.relax_timeout is None + + +class TestMaintenanceEventPoolHandler: + """Test the MaintenanceEventPoolHandler class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_pool = Mock() + self.mock_pool._lock = MagicMock() + self.mock_pool._lock.__enter__.return_value = None + self.mock_pool._lock.__exit__.return_value = None + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=20 + ) + self.handler = MaintenanceEventPoolHandler(self.mock_pool, self.config) + + def test_init(self): + """Test MaintenanceEventPoolHandler initialization.""" + assert self.handler.pool == self.mock_pool + assert self.handler.config == self.config + assert isinstance(self.handler._processed_events, set) + assert isinstance(self.handler._lock, type(threading.RLock())) + + def test_remove_expired_notifications(self): + """Test removal of expired notifications.""" + with patch("time.monotonic", return_value=1000): + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="host2", new_node_port=6380, ttl=5 + ) + self.handler._processed_events.add(event1) + self.handler._processed_events.add(event2) + + # Move time forward but not enough to expire event2 (expires at 1005) + with patch("time.monotonic", return_value=1003): + self.handler.remove_expired_notifications() + assert event1 in self.handler._processed_events + assert event2 in self.handler._processed_events # Not expired yet + + # Move time forward to expire event2 but not event1 + with patch("time.monotonic", return_value=1006): + self.handler.remove_expired_notifications() + assert event1 in self.handler._processed_events + assert event2 not in self.handler._processed_events # Now expired + + def test_handle_event_node_moving(self): + """Test handling of NodeMovingEvent.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch.object(self.handler, "handle_node_moving_event") as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_unknown_type(self): + """Test handling of unknown event type.""" + event = NodeMigratingEvent(id=1, ttl=5) # Not handled by pool handler + + result = self.handler.handle_event(event) + assert result is None + + def test_handle_node_moving_event_disabled_config(self): + """Test node moving event handling when both features are disabled.""" + config = MaintenanceEventsConfig(proactive_reconnect=False, relax_timeout=-1) + handler = MaintenanceEventPoolHandler(self.mock_pool, config) + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + result = 
handler.handle_node_moving_event(event) + assert result is None + assert event not in handler._processed_events + + def test_handle_node_moving_event_already_processed(self): + """Test node moving event handling when event already processed.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + self.handler._processed_events.add(event) + + result = self.handler.handle_node_moving_event(event) + assert result is None + + def test_handle_node_moving_event_success(self): + """Test successful node moving event handling.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with ( + patch("threading.Timer") as mock_timer, + patch("time.monotonic", return_value=1000), + ): + self.handler.handle_node_moving_event(event) + + # Verify timer was started + mock_timer.assert_called_once_with( + event.ttl, self.handler.handle_node_moved_event, args=(event,) + ) + mock_timer.return_value.start.assert_called_once() + + # Verify event was added to processed set + assert event in self.handler._processed_events + + # Verify pool methods were called + self.mock_pool.update_connections_settings.assert_called_once() + + def test_handle_node_moved_event(self): + """Test handling of node moved event (cleanup).""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + self.mock_pool.connection_kwargs = {"host": "localhost"} + self.handler.handle_node_moved_event(event) + + # Verify cleanup methods were called + self.mock_pool.update_connections_settings.assert_called_once() + + +class TestMaintenanceEventConnectionHandler: + """Test the MaintenanceEventConnectionHandler class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_connection = Mock() + self.config = MaintenanceEventsConfig(enabled=True, relax_timeout=20) + self.handler = MaintenanceEventConnectionHandler( + self.mock_connection, self.config + ) + + def test_init(self): + """Test MaintenanceEventConnectionHandler initialization.""" + assert self.handler.connection == self.mock_connection + assert self.handler.config == self.config + + def test_handle_event_migrating(self): + """Test handling of NodeMigratingEvent.""" + event = NodeMigratingEvent(id=1, ttl=5) + + with patch.object(self.handler, "handle_migrating_event") as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_migrated(self): + """Test handling of NodeMigratedEvent.""" + event = NodeMigratedEvent(id=1) + + with patch.object( + self.handler, "handle_migration_completed_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_unknown_type(self): + """Test handling of unknown event type.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + result = self.handler.handle_event(event) + assert result is None + + def test_handle_migrating_event_disabled(self): + """Test migrating event handling when relax timeouts are disabled.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + handler = MaintenanceEventConnectionHandler(self.mock_connection, config) + event = NodeMigratingEvent(id=1, ttl=5) + + result = handler.handle_migrating_event(event) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_migrating_event_success(self): + """Test successful migrating event handling.""" + event = 
NodeMigratingEvent(id=1, ttl=5) + + self.handler.handle_migrating_event(event) + + self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) + self.mock_connection.set_tmp_settings.assert_called_once_with( + tmp_relax_timeout=20 + ) + + def test_handle_migration_completed_event_disabled(self): + """Test migration completed event handling when relax timeouts are disabled.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + handler = MaintenanceEventConnectionHandler(self.mock_connection, config) + event = NodeMigratedEvent(id=1) + + result = handler.handle_migration_completed_event(event) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_migration_completed_event_success(self): + """Test successful migration completed event handling.""" + event = NodeMigratedEvent(id=1) + + self.handler.handle_migration_completed_event(event) + + self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) + self.mock_connection.reset_tmp_settings.assert_called_once_with( + reset_relax_timeout=True + ) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py new file mode 100644 index 0000000000..8db8d182a7 --- /dev/null +++ b/tests/test_maintenance_events_handling.py @@ -0,0 +1,1770 @@ +import socket +import threading +from typing import List, Union +from unittest.mock import patch + +import pytest +from time import sleep + +from redis import Redis +from redis.connection import ( + AbstractConnection, + ConnectionPool, + BlockingConnectionPool, + MaintenanceState, +) +from redis.maintenance_events import ( + MaintenanceEventsConfig, + NodeMigratingEvent, + MaintenanceEventPoolHandler, + NodeMovingEvent, + NodeMigratedEvent, +) + + +AFTER_MOVING_ADDRESS = "1.2.3.4:6379" +DEFAULT_ADDRESS = "12.45.34.56:6379" +MOVING_TIMEOUT = 1 + + +class Helpers: + """Helper class containing static methods for validation in maintenance events tests.""" + + @staticmethod + def validate_in_use_connections_state( + in_use_connections: List[AbstractConnection], + expected_state=MaintenanceState.NONE, + expected_should_reconnect: Union[bool, str] = True, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ): + """Helper method to validate state of in-use connections.""" + + # validate that in-use connections keep working, have the reconnect flag + # set, and have their timeouts updated + for connection in in_use_connections: + if expected_should_reconnect != "any": + assert connection._should_reconnect == expected_should_reconnect + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) + if connection._sock is not None: + assert connection._sock.gettimeout() == expected_current_socket_timeout + assert connection._sock.connected is True + if expected_current_peername != "any": + assert ( 
connection._sock.getpeername()[0] == expected_current_peername + ) + assert connection.maintenance_state == expected_state + + @staticmethod + def validate_free_connections_state( + pool, + should_be_connected_count=0, + connected_to_tmp_address=False, + tmp_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_state=MaintenanceState.MOVING, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ): + """Helper method to validate state of free/available connections.""" + + if isinstance(pool, BlockingConnectionPool): + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif isinstance(pool, ConnectionPool): + free_connections = pool._available_connections + else: + raise ValueError(f"Unsupported pool type: {type(pool)}") + + connected_count = 0 + for connection in free_connections: + assert connection._should_reconnect is False + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) + assert connection.maintenance_state == expected_state + if connection._sock is not None: + assert connection._sock.connected is True + if connected_to_tmp_address and tmp_address != "any": + assert connection._sock.getpeername()[0] == tmp_address + connected_count += 1 + assert connected_count == should_be_connected_count + + @staticmethod + def validate_conn_kwargs( + pool, + expected_host_address, + expected_port, + expected_socket_timeout, + expected_socket_connect_timeout, + expected_orig_host_address, + expected_orig_socket_timeout, + expected_orig_socket_connect_timeout, + ): + """Helper method to validate connection kwargs.""" + assert pool.connection_kwargs["host"] == expected_host_address + assert pool.connection_kwargs["port"] == expected_port + assert pool.connection_kwargs["socket_timeout"] == expected_socket_timeout + assert ( + pool.connection_kwargs["socket_connect_timeout"] + == expected_socket_connect_timeout + ) + assert ( + pool.connection_kwargs.get("orig_host_address", None) + == expected_orig_host_address + ) + assert ( + pool.connection_kwargs.get("orig_socket_timeout", None) + == expected_orig_socket_timeout + ) + assert ( + pool.connection_kwargs.get("orig_socket_connect_timeout", None) + == expected_orig_socket_connect_timeout + ) + + +class MockSocket: + """Mock socket that simulates Redis protocol responses.""" + + def __init__(self): + self.connected = False + self.address = None + self.sent_data = [] + self.closed = False + self.command_count = 0 + self.pending_responses = [] + # Track socket timeout changes for maintenance events validation + self.timeout = None + self.thread_timeouts = {} # Track last applied timeout per thread + self.moving_sent = False + + def connect(self, address): + """Simulate socket connection.""" + self.connected = True + self.address = address + + def send(self, data): + """Simulate sending data to Redis.""" + if self.closed: + raise ConnectionError("Socket is closed") + self.sent_data.append(data) + + # Analyze the command and prepare appropriate 
response + if b"HELLO" in data: + response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.0.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n" + self.pending_responses.append(response) + elif b"SET" in data: + response = b"+OK\r\n" + + # Check if this is a key that should trigger a push message + if b"key_receive_migrating" in data: # also matches key_receive_migrating_<n> + # MIGRATING push message before SET key_receive_migrating_X response + # Format: >2\r\n$9\r\nMIGRATING\r\n:10\r\n (2 elements: MIGRATING, ttl) + migrating_push = ">2\r\n$9\r\nMIGRATING\r\n:10\r\n" + response = migrating_push.encode() + response + elif b"key_receive_migrated" in data: # also matches key_receive_migrated_<n> + # MIGRATED push message before SET key_receive_migrated_X response + # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) + migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" + response = migrated_push.encode() + response + elif b"key_receive_moving_" in data: + # MOVING push message before SET key_receive_moving_X response + # Format: >3\r\n$6\r\nMOVING\r\n:<ttl>\r\n+<host:port>\r\n (3 elements: MOVING, ttl, host:port) + # Note: Using + instead of $ to send as simple string instead of bulk string + moving_push = f">3\r\n$6\r\nMOVING\r\n:{MOVING_TIMEOUT}\r\n+{AFTER_MOVING_ADDRESS}\r\n" + response = moving_push.encode() + response + + self.pending_responses.append(response) + elif b"GET" in data: + # Extract key and provide appropriate response + if b"hello" in data: + response = b"$5\r\nworld\r\n" + self.pending_responses.append(response) + # Handle specific keys used in tests + elif b"key_receive_moving_0" in data: + self.pending_responses.append(b"$8\r\nvalue3_0\r\n") + elif b"key_receive_migrated_0" in data: + self.pending_responses.append(b"$14\r\nmigrated_value\r\n") + elif b"key_receive_migrating" in data: + self.pending_responses.append(b"$6\r\nvalue2\r\n") + elif b"key_receive_migrated" in data: + self.pending_responses.append(b"$6\r\nvalue3\r\n") + elif b"key1" in data: + self.pending_responses.append(b"$6\r\nvalue1\r\n") + else: + self.pending_responses.append(b"$-1\r\n") # NULL response + else: + self.pending_responses.append(b"+OK\r\n") # Default response + + self.command_count += 1 + return len(data) + + def sendall(self, data): + """Simulate sending all data to Redis.""" + return self.send(data) + + def recv(self, bufsize): + """Simulate receiving data from Redis.""" + if self.closed: + raise ConnectionError("Socket is closed") + + # Use pending responses that were prepared when commands were sent + if self.pending_responses: + response = self.pending_responses.pop(0) + if b"MOVING" in response: + self.moving_sent = True + return response[:bufsize] # Respect buffer size + else: + # No data available - this should block or raise an exception + # For can_read checks, we should indicate no data is available + import errno + + raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable") + + def fileno(self): + """Return a fake file descriptor for select/poll operations.""" + return 1 # Fake file descriptor + + def close(self): + """Simulate closing the socket.""" + self.closed = True + self.connected = False + self.address = None + self.timeout = None + self.thread_timeouts = {} + + def settimeout(self, timeout): + """Simulate setting socket timeout and track changes per thread.""" + self.timeout = timeout + + # Track the last applied timeout, keyed by thread id + thread_id = 
threading.current_thread().ident + self.thread_timeouts[thread_id] = timeout + + def gettimeout(self): + """Simulate getting socket timeout.""" + return self.timeout + + def setsockopt(self, level, optname, value): + """Simulate setting socket options.""" + pass + + def getpeername(self): + """Simulate getting peer name.""" + return self.address + + def getsockname(self): + """Simulate getting socket name.""" + # self.address is the (host, port) tuple passed to connect() + return (self.address[0], 12345) + + def shutdown(self, how): + """Simulate socket shutdown.""" + pass + + +class TestMaintenanceEventsHandlingSingleProxy: + """Integration tests for maintenance events handling with real connection pool.""" + + def setup_method(self): + """Set up test fixtures with mocked sockets.""" + self.mock_sockets = [] + self.original_socket = socket.socket + + # Mock socket creation to return our mock sockets + def mock_socket_factory(*args, **kwargs): + mock_sock = MockSocket() + self.mock_sockets.append(mock_sock) + return mock_sock + + self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory) + self.socket_patcher.start() + + # Mock select.select to simulate data availability for reading + def mock_select(rlist, wlist, xlist, timeout=0): + # Check if any of the sockets in rlist have data available + ready_sockets = [] + for sock in rlist: + if hasattr(sock, "connected") and sock.connected and not sock.closed: + # Only return socket as ready if it actually has data to read + if hasattr(sock, "pending_responses") and sock.pending_responses: + ready_sockets.append(sock) + # Don't return socket as ready just because it received commands + # Only when there are actual responses available + return (ready_sockets, [], []) + + self.select_patcher = patch("select.select", side_effect=mock_select) + self.select_patcher.start() + + # Create maintenance events config + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + + def teardown_method(self): + """Clean up test fixtures.""" + self.socket_patcher.stop() + self.select_patcher.stop() + + def _get_client( + self, + pool_class, + max_connections=10, + maintenance_events_config=None, + setup_pool_handler=False, + ): + """Helper method to create a pool and Redis client with maintenance events configuration. + + Args: + pool_class: The connection pool class (ConnectionPool or BlockingConnectionPool) + max_connections: Maximum number of connections in the pool (default: 10) + maintenance_events_config: Optional MaintenanceEventsConfig to use. 
If not provided, + uses self.config from setup_method (default: None) + setup_pool_handler: Whether to set up pool handler for moving events (default: False) + + Returns: + Redis: the configured client; its pool is available as + test_redis_client.connection_pool + """ + config = ( + maintenance_events_config + if maintenance_events_config is not None + else self.config + ) + + test_pool = pool_class( + host=DEFAULT_ADDRESS.split(":")[0], + port=int(DEFAULT_ADDRESS.split(":")[1]), + max_connections=max_connections, + protocol=3, # Required for maintenance events + maintenance_events_config=config, + ) + test_redis_client = Redis(connection_pool=test_pool) + + # Set up pool handler for moving events if requested + if setup_pool_handler: + pool_handler = MaintenanceEventPoolHandler( + test_redis_client.connection_pool, config + ) + test_redis_client.connection_pool.set_maintenance_events_pool_handler( + pool_handler + ) + + return test_redis_client + + def _validate_connection_handlers(self, conn, pool_handler, config): + """Helper method to validate connection handlers are properly set.""" + # Test that the node moving handler function is correctly set + parser_handler = conn._parser.node_moving_push_handler_func + assert parser_handler is not None + assert hasattr(parser_handler, "__self__") + assert hasattr(parser_handler, "__func__") + assert parser_handler.__self__ is pool_handler + assert parser_handler.__func__ is pool_handler.handle_event.__func__ + + # Test that the maintenance handler function is correctly set + maintenance_handler = conn._parser.maintenance_push_handler_func + assert maintenance_handler is not None + assert hasattr(maintenance_handler, "__self__") + assert hasattr(maintenance_handler, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance event connection handler + assert ( + maintenance_handler.__self__ is conn._maintenance_event_connection_handler + ) + assert ( + maintenance_handler.__func__ + is conn._maintenance_event_connection_handler.handle_event.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maintenance_event_connection_handler.config is config + + def _validate_current_timeout(self, expected_timeout, error_msg=None): + """Helper method to validate the current timeout for the calling thread.""" + actual_timeout = None + # Get the actual thread ID from the current thread + current_thread_id = threading.current_thread().ident + for sock in self.mock_sockets: + if current_thread_id in sock.thread_timeouts: + actual_timeout = sock.thread_timeouts[current_thread_id] + break + + assert actual_timeout == expected_timeout, ( + f"{error_msg or ''}" + f"Expected timeout ({expected_timeout}), " + f"but found timeout: {actual_timeout}. 
" + f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}", + ) + + def _validate_disconnected(self, expected_count): + """Helper method to validate all socket timeouts""" + disconnected_sockets_count = 0 + for sock in self.mock_sockets: + if sock.closed: + disconnected_sockets_count += 1 + assert disconnected_sockets_count == expected_count + + def _validate_connected(self, expected_count): + """Helper method to validate all socket timeouts""" + connected_sockets_count = 0 + for sock in self.mock_sockets: + if sock.connected: + connected_sockets_count += 1 + assert connected_sockets_count == expected_count + + def _validate_all_timeouts(self, expected_timeout): + """Helper method to validate state of in-use connections.""" + # validate in use connections are still working with set flag for reconnect + # and timeout is updated + for mock_socket in self.mock_sockets: + assert mock_socket.gettimeout() == expected_timeout + + def test_client_initialization(self): + """Test that Redis client is created with maintenance events configuration.""" + # Create a pool and Redis client with maintenance events + + test_redis_client = Redis( + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + pool_handler = test_redis_client.connection_pool.connection_kwargs.get( + "maintenance_events_pool_handler" + ) + assert pool_handler is not None + assert pool_handler.config == self.config + + conn = test_redis_client.connection_pool.get_connection() + assert conn._should_reconnect is False + assert conn.orig_host_address == "localhost" + assert conn.orig_socket_timeout is None + + # Test that the node moving handler function is correctly set by + # comparing the underlying function and instance + parser_handler = conn._parser.node_moving_push_handler_func + assert parser_handler is not None + assert hasattr(parser_handler, "__self__") + assert hasattr(parser_handler, "__func__") + assert parser_handler.__self__ is pool_handler + assert parser_handler.__func__ is pool_handler.handle_event.__func__ + + # Test that the maintenance handler function is correctly set + maintenance_handler = conn._parser.maintenance_push_handler_func + assert maintenance_handler is not None + assert hasattr(maintenance_handler, "__self__") + assert hasattr(maintenance_handler, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance event connection handler + assert ( + maintenance_handler.__self__ is conn._maintenance_event_connection_handler + ) + assert ( + maintenance_handler.__func__ + is conn._maintenance_event_connection_handler.handle_event.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maintenance_event_connection_handler.config is self.config + + def test_maint_handler_init_for_existing_connections(self): + """Test that maintenance event handlers are properly set on existing and new connections + when configuration is enabled after client creation.""" + + # Create a Redis client with disabled maintenance events configuration + disabled_config = MaintenanceEventsConfig(enabled=False) + test_redis_client = Redis( + protocol=3, # Required for maintenance events + maintenance_events_config=disabled_config, + ) + + # Extract an existing connection before enabling maintenance events + existing_conn = test_redis_client.connection_pool.get_connection() + + # Verify that maintenance events are initially disabled + assert 
existing_conn._parser.node_moving_push_handler_func is None + assert existing_conn._maintenance_event_connection_handler is None + assert existing_conn._parser.maintenance_push_handler_func is None + + # Create a new enabled configuration and set up pool handler + enabled_config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + pool_handler = MaintenanceEventPoolHandler( + test_redis_client.connection_pool, enabled_config + ) + test_redis_client.connection_pool.set_maintenance_events_pool_handler( + pool_handler + ) + + # Validate the existing connection after enabling maintenance events + # Both existing and new connections should now have full handler setup + self._validate_connection_handlers(existing_conn, pool_handler, enabled_config) + + # Create a new connection and validate it has full handlers + new_conn = test_redis_client.connection_pool.get_connection() + self._validate_connection_handlers(new_conn, pool_handler, enabled_config) + + # Clean up connections + test_redis_client.connection_pool.release(existing_conn) + test_redis_client.connection_pool.release(new_conn) + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_connection_pool_creation_with_maintenance_events(self, pool_class): + """Test that connection pools are created with maintenance events configuration.""" + # Create a pool and Redis client with maintenance events + max_connections = 3 if pool_class == BlockingConnectionPool else 10 + test_redis_client = self._get_client( + pool_class, max_connections=max_connections + ) + test_pool = test_redis_client.connection_pool + + try: + assert ( + test_pool.connection_kwargs.get("maintenance_events_config") + == self.config + ) + # Pool should have maintenance events enabled + assert test_pool.maintenance_events_pool_handler_enabled() is True + + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(test_pool, self.config) + test_pool.set_maintenance_events_pool_handler(pool_handler) + + # Validate that the handler is properly set on the pool + assert ( + test_pool.connection_kwargs.get("maintenance_events_pool_handler") + == pool_handler + ) + assert ( + test_pool.connection_kwargs.get("maintenance_events_config") + == pool_handler.config + ) + + # Verify that the pool handler has the correct configuration + assert pool_handler.pool == test_pool + assert pool_handler.config == self.config + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_redis_operations_with_mock_sockets(self, pool_class): + """ + Test basic Redis operations work with mocked sockets and proper response parsing. + This test also exercises and validates the MockSocket implementation itself. 
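+ + Illustrative outline of the mocked exchange (the response bytes are the + ones hardcoded in MockSocket above): + + SET hello world -> +OK\r\n (client returns True) + GET hello -> $5\r\nworld\r\n (client returns b"world")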
+ """ + # Create a pool and Redis client with maintenance events + test_redis_client = self._get_client(pool_class, max_connections=5) + + try: + # Perform Redis operations that should work with our improved mock responses + result_set = test_redis_client.set("hello", "world") + result_get = test_redis_client.get("hello") + + # Verify operations completed successfully + assert result_set is True + assert result_get == b"world" + + # Verify socket interactions + assert len(self.mock_sockets) >= 1 + assert self.mock_sockets[0].connected + assert len(self.mock_sockets[0].sent_data) >= 2 # HELLO, SET, GET commands + + # Verify that the connection has maintenance event handler + connection = test_redis_client.connection_pool.get_connection() + assert hasattr(connection, "_maintenance_event_connection_handler") + test_redis_client.connection_pool.release(connection) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + def test_pool_handler_with_migrating_event(self): + """Test that pool handler correctly handles migrating events.""" + # Create a pool and Redis client with maintenance events + test_redis_client = self._get_client(ConnectionPool) + test_pool = test_redis_client.connection_pool + + try: + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(test_pool, self.config) + + # Create a migrating event (not handled by pool handler) + migrating_event = NodeMigratingEvent(id=1, ttl=5) + + # Mock the required functions + with ( + patch.object( + pool_handler, "remove_expired_notifications" + ) as mock_remove_expired, + patch.object( + pool_handler, "handle_node_moving_event" + ) as mock_handle_moving, + patch("redis.maintenance_events.logging.error") as mock_logging_error, + ): + # Pool handler should return None for migrating events (not its responsibility) + pool_handler.handle_event(migrating_event) + + # Validate that remove_expired_notifications has been called once + mock_remove_expired.assert_called_once() + + # Validate that handle_node_moving_event hasn't been called + mock_handle_moving.assert_not_called() + + # Validate that logging.error has been called once + mock_logging_error.assert_called_once() + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migration_related_events_handling_integration(self, pool_class): + """ + Test full integration of migration-related events (MIGRATING/MIGRATED) handling. + + This test validates the complete migration lifecycle: + 1. Executes 5 Redis commands sequentially + 2. Injects MIGRATING push message before command 2 (SET key_receive_migrating) + 3. Validates socket timeout is updated to relaxed value (30s) after MIGRATING + 4. Executes commands 3-4 while timeout remains relaxed + 5. Injects MIGRATED push message before command 5 (SET key_receive_migrated) + 6. Validates socket timeout is restored after MIGRATED + 7. Tests both ConnectionPool and BlockingConnectionPool implementations + 8. 
Uses proper RESP3 push message format for realistic protocol simulation + """ + # Create a pool and Redis client with maintenance events + test_redis_client = self._get_client(pool_class, max_connections=10) + + try: + # Command 1: Initial command + key1 = "key1" + value1 = "value1" + result1 = test_redis_client.set(key1, value1) + + # Validate Command 1 result + assert result1 is True, "Command 1 (SET key1) failed" + + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = "key_receive_migrating" + value_migrating = "value2" + result2 = test_redis_client.set(key_migrating, value_migrating) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_migrating) failed" + + # Step 4: Validate timeout was updated to relaxed value after MIGRATING + self._validate_current_timeout(30, "Right after MIGRATING is received. ") + + # Command 3: Another command while timeout is still relaxed + result3 = test_redis_client.get(key1) + + # Validate Command 3 result + expected_value3 = value1.encode() + assert result3 == expected_value3, ( + f"Command 3 (GET key1) failed. Expected {expected_value3}, got {result3}" + ) + + # Command 4: Another GET while the timeout is still relaxed (step 5) + result4 = test_redis_client.get(key_migrating) + + # Validate Command 4 result + expected_value4 = value_migrating.encode() + assert result4 == expected_value4, ( + f"Command 4 (GET key_receive_migrating) failed. Expected {expected_value4}, got {result4}" + ) + + # Step 6: Validate socket timeout is still relaxed during commands 3-4 + self._validate_current_timeout( + 30, + "Execute a command with a connection extracted from the pool (after it has received MIGRATING)", + ) + + # Command 5: This SET command will receive + # MIGRATED push message before actual response + key_migrated = "key_receive_migrated" + value_migrated = "value3" + result5 = test_redis_client.set(key_migrated, value_migrated) + + # Validate Command 5 result + assert result5 is True, "Command 5 (SET key_receive_migrated) failed" + + # Step 8: Validate socket timeout is restored to the original value after MIGRATED + self._validate_current_timeout(None) + + # Verify maintenance events were processed correctly + # The key is that we have at least 1 socket and all operations succeeded + assert len(self.mock_sockets) >= 1, ( + f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" + ) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migrating_event_with_disabled_relax_timeout(self, pool_class): + """ + Test migrating event handling when relax timeout is disabled. + + This test validates that when relax_timeout is disabled (-1): + 1. MIGRATING events are received and processed + 2. No timeout updates are applied to connections + 3. Socket timeouts remain unchanged during migration events + 4. 
Tests both ConnectionPool and BlockingConnectionPool implementations + """ + # Create config with disabled relax timeout + disabled_config = MaintenanceEventsConfig( + enabled=True, + relax_timeout=-1, # This means the relax timeout is disabled + ) + + # Create a pool and Redis client with disabled relax timeout config + test_redis_client = self._get_client( + pool_class, max_connections=5, maintenance_events_config=disabled_config + ) + + try: + # Command 1: Initial command + key1 = "key1" + value1 = "value1" + result1 = test_redis_client.set(key1, value1) + + # Validate Command 1 result + assert result1 is True, "Command 1 (SET key1) failed" + + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = "key_receive_migrating" + value_migrating = "value2" + result2 = test_redis_client.set(key_migrating, value_migrating) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_migrating) failed" + + # Validate timeout was NOT updated (relax is disabled) + # Should remain at default timeout (None), not relaxed to 30s + self._validate_current_timeout(None) + + # Command 3: Another command to verify timeout remains unchanged + result3 = test_redis_client.get(key1) + + # Validate Command 3 result + expected_value3 = value1.encode() + assert result3 == expected_value3, ( + f"Command 3 (GET key1) failed. Expected: {expected_value3}, Got: {result3}" + ) + + # Verify maintenance events were processed correctly + # The key is that we have at least 1 socket and all operations succeeded + assert len(self.mock_sockets) >= 1, ( + f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" + ) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_moving_related_events_handling_integration(self, pool_class): + """ + Test full integration of moving-related events (MOVING) handling with Redis commands. + + This test validates the complete MOVING event lifecycle: + 1. Creates multiple connections in the pool + 2. Executes a Redis command that triggers a MOVING push message + 3. Validates that pool configuration is updated with the temporary + address and timeout used for new connection creation + 4. Validates that existing connections are marked for disconnection + 5. 
Tests both ConnectionPool and BlockingConnectionPool implementations + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(10): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 5 connections to be "in use" + in_use_connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior to the MOVING event + self._validate_disconnected(0) + + # Run command that will receive and handle the MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + # the connection used for the command is expected to be reconnected to the new address + # before it is returned to the pool + result2 = test_redis_client.set(key_moving, value_moving) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_moving) failed" + + # Validate pool and connections settings were updated according to the MOVING event + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + self._validate_disconnected(5) + self._validate_connected(6) + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[ + 0 + ], # the in use connections reconnect when they complete their current task + ) + Helpers.validate_free_connections_state( + pool=test_redis_client.connection_pool, + expected_state=MaintenanceState.MOVING, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + should_be_connected_count=1, + connected_to_tmp_address=True, + ) + # Wait for MOVING timeout to expire and the moving completed handler to run + sleep(MOVING_TIMEOUT + 0.5) + + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.NONE, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + 
expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + Helpers.validate_free_connections_state( + pool=test_redis_client.connection_pool, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + should_be_connected_count=1, + connected_to_tmp_address=True, + expected_state=MaintenanceState.NONE, + ) + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_create_new_conn_while_moving_not_expired(self, pool_class): + """ + Test creating new connections while MOVING event is active (not expired). + + This test validates that: + 1. After MOVING event is processed, new connections are created with temporary address + 2. New connections inherit the relaxed timeout settings + 3. Pool configuration is properly applied to newly created connections + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior to the MOVING event + self._validate_disconnected(0) + + # Run command that will receive and handle the MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result = test_redis_client.set(key_moving, value_moving) + + # Validate command result + assert result is True, "SET key_receive_moving command failed" + + # Validate pool and connections settings were updated according to the MOVING event + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + ) + + # Now get several more connections to force creation of new ones + # This should create new connections with the temporary address + old_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + old_connections.append(connection) + + new_connection = test_redis_client.connection_pool.get_connection() + + # Validate that new connections are created with temporary address and relax 
timeout
+            # and that those settings are used when the connection is established
+            # get_connection() returns a connection that is already connected
+            assert new_connection.host == AFTER_MOVING_ADDRESS.split(":")[0]
+            assert new_connection.socket_timeout == self.config.relax_timeout
+            # New connections should be connected to the temporary address
+            assert new_connection._sock is not None
+            assert new_connection._sock.connected is True
+            assert (
+                new_connection._sock.getpeername()[0]
+                == AFTER_MOVING_ADDRESS.split(":")[0]
+            )
+            assert new_connection._sock.gettimeout() == self.config.relax_timeout
+
+        finally:
+            if hasattr(test_redis_client.connection_pool, "disconnect"):
+                test_redis_client.connection_pool.disconnect()
+
+    @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool])
+    def test_create_new_conn_after_moving_expires(self, pool_class):
+        """
+        Test creating new connections after the MOVING event expires.
+
+        This test validates that:
+        1. After the MOVING timeout expires, new connections use the original address
+        2. Pool configuration is reset to original values
+        3. New connections don't inherit temporary settings
+        """
+        # Create a pool and Redis client with maintenance events and pool handler
+        test_redis_client = self._get_client(
+            pool_class, max_connections=10, setup_pool_handler=True
+        )
+
+        try:
+            # Create several connections and release them back to the pool
+            connections = []
+            for _ in range(5):
+                connection = test_redis_client.connection_pool.get_connection()
+                connections.append(connection)
+
+            for connection in connections:
+                test_redis_client.connection_pool.release(connection)
+
+            # Take 3 connections to be "in use"
+            in_use_connections = []
+            for _ in range(3):
+                connection = test_redis_client.connection_pool.get_connection()
+                in_use_connections.append(connection)
+
+            # Run command that will receive and handle the MOVING event
+            key_moving = "key_receive_moving_0"
+            value_moving = "value3_0"
+            result = test_redis_client.set(key_moving, value_moving)
+
+            # Validate command result
+            assert result is True, "SET key_receive_moving command failed"
+
+            # Wait for MOVING timeout to expire
+            sleep(MOVING_TIMEOUT + 0.5)
+
+            # Now get several new connections after expiration
+            old_connections = []
+            for _ in range(2):
+                connection = test_redis_client.connection_pool.get_connection()
+                old_connections.append(connection)
+
+            new_connection = test_redis_client.connection_pool.get_connection()
+
+            # Validate that new connections are created with the original address
+            # (no temporary settings)
+            assert new_connection.orig_host_address == DEFAULT_ADDRESS.split(":")[0]
+            assert new_connection.orig_socket_timeout is None
+            # New connections should be connected to the original address
+            assert new_connection._sock is not None
+            assert new_connection._sock.connected is True
+            # Socket timeout should be None (original timeout)
+            assert new_connection._sock.gettimeout() is None
+
+        finally:
+            if hasattr(test_redis_client.connection_pool, "disconnect"):
+                test_redis_client.connection_pool.disconnect()
+
+    @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool])
+    def test_receive_migrated_after_moving(self, pool_class):
+        """
+        Test receiving MIGRATED event after MOVING event.
+
+        This test validates the complete MOVING -> MIGRATED lifecycle:
+        1. MOVING event is processed and temporary settings are applied
+        2. MIGRATED event is received during command execution
+        3. Temporary settings are cleared after MIGRATED
+        4. 
Pool configuration is restored to original values + + Note: When MIGRATED comes after MOVING and MOVING hasn't yet expired, + it should not decrease timeouts (future refactoring consideration). + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Step 1: Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result_moving = test_redis_client.set(key_moving, value_moving) + + # Validate MOVING command result + assert result_moving is True, "SET key_receive_moving command failed" + + # Validate pool and connections settings were updated according to MOVING event + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + ) + + # TODO validate current socket timeout + + # Step 2: Run command that will receive and handle MIGRATED event + # This should clear the temporary settings + key_migrated = "key_receive_migrated_0" + value_migrated = "migrated_value" + result_migrated = test_redis_client.set(key_migrated, value_migrated) + + # Validate MIGRATED command result + assert result_migrated is True, "SET key_receive_migrated command failed" + + # Step 3: Validate that MIGRATED event was processed but MOVING settings remain + # (MIGRATED doesn't automatically clear MOVING settings - they are separate events) + # MOVING settings should still be active + # MOVING timeout should still be active + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + ) + + # Step 4: Create new connections after MIGRATED to verify they still use MOVING settings + # (since MOVING settings are still active) + new_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + new_connections.append(connection) + + # Validate that new connections are created with MOVING settings (still active) + for connection in new_connections: + assert connection.host == AFTER_MOVING_ADDRESS.split(":")[0] + # Note: New connections may not inherit the exact relax timeout value + # but they should have the temporary host address + # New connections should be 
connected + if connection._sock is not None: + assert connection._sock.connected is True + + # Release the new connections + for connection in new_connections: + test_redis_client.connection_pool.release(connection) + + # Validate free connections state with MOVING settings still active + # Note: We'll validate with the pool's current settings rather than individual connection settings + # since new connections may have different timeout values but still use the temporary address + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_overlapping_moving_events(self, pool_class): + """ + Test handling of overlapping/duplicate MOVING events (e.g., two MOVING events before the first expires). + Ensures that the second MOVING event updates the pool and connections as expected, and that expiry/cleanup works. + """ + global AFTER_MOVING_ADDRESS + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + try: + # Create and release some connections + in_use_connections = [] + for _ in range(3): + in_use_connections.append( + test_redis_client.connection_pool.get_connection() + ) + + for conn in in_use_connections: + test_redis_client.connection_pool.release(conn) + + # Take 2 connections to be in use + in_use_connections = [] + for _ in range(2): + conn = test_redis_client.connection_pool.get_connection() + in_use_connections.append(conn) + + # Trigger first MOVING event + key_moving1 = "key_receive_moving_0" + value_moving1 = "value3_0" + result1 = test_redis_client.set(key_moving1, value_moving1) + assert result1 is True + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + # Validate all connections reflect the first MOVING event + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + Helpers.validate_free_connections_state( + pool=test_redis_client.connection_pool, + should_be_connected_count=1, + connected_to_tmp_address=True, + expected_state=MaintenanceState.MOVING, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + # Reconnect in use connections + for conn in in_use_connections: + conn.disconnect() + conn.connect() + + # Before the first MOVING expires, trigger a second MOVING event (simulate new address) + # Validate the orig properties are not 
changed!
+            second_moving_address = "5.6.7.8:6380"
+            orig_after_moving = AFTER_MOVING_ADDRESS
+            # Temporarily modify the global constant for this test
+            AFTER_MOVING_ADDRESS = second_moving_address
+            try:
+                key_moving2 = "key_receive_moving_1"
+                value_moving2 = "value3_1"
+                result2 = test_redis_client.set(key_moving2, value_moving2)
+                assert result2 is True
+                Helpers.validate_conn_kwargs(
+                    pool=test_redis_client.connection_pool,
+                    expected_host_address=second_moving_address.split(":")[0],
+                    expected_port=int(DEFAULT_ADDRESS.split(":")[1]),
+                    expected_socket_timeout=self.config.relax_timeout,
+                    expected_socket_connect_timeout=self.config.relax_timeout,
+                    expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0],
+                    expected_orig_socket_timeout=None,
+                    expected_orig_socket_connect_timeout=None,
+                )
+                # Validate all connections reflect the second MOVING event
+                Helpers.validate_in_use_connections_state(
+                    in_use_connections,
+                    expected_state=MaintenanceState.MOVING,
+                    expected_host_address=second_moving_address.split(":")[0],
+                    expected_socket_timeout=self.config.relax_timeout,
+                    expected_socket_connect_timeout=self.config.relax_timeout,
+                    expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0],
+                    expected_orig_socket_timeout=None,
+                    expected_orig_socket_connect_timeout=None,
+                    expected_current_socket_timeout=self.config.relax_timeout,
+                    expected_current_peername=orig_after_moving.split(":")[0],
+                )
+                Helpers.validate_free_connections_state(
+                    test_redis_client.connection_pool,
+                    should_be_connected_count=1,
+                    connected_to_tmp_address=True,
+                    tmp_address=second_moving_address.split(":")[0],
+                    expected_state=MaintenanceState.MOVING,
+                    expected_host_address=second_moving_address.split(":")[0],
+                    expected_socket_timeout=self.config.relax_timeout,
+                    expected_socket_connect_timeout=self.config.relax_timeout,
+                    expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0],
+                    expected_orig_socket_timeout=None,
+                    expected_orig_socket_connect_timeout=None,
+                )
+            finally:
+                AFTER_MOVING_ADDRESS = orig_after_moving
+
+            # Wait for both MOVING timeouts to expire
+            sleep(MOVING_TIMEOUT + 0.5)
+            Helpers.validate_conn_kwargs(
+                pool=test_redis_client.connection_pool,
+                expected_host_address=DEFAULT_ADDRESS.split(":")[0],
+                expected_port=int(DEFAULT_ADDRESS.split(":")[1]),
+                expected_socket_timeout=None,
+                expected_socket_connect_timeout=None,
+                expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0],
+                expected_orig_socket_timeout=None,
+                expected_orig_socket_connect_timeout=None,
+            )
+        finally:
+            if hasattr(test_redis_client.connection_pool, "disconnect"):
+                test_redis_client.connection_pool.disconnect()
+
+    @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool])
+    def test_thread_safety_concurrent_event_handling(self, pool_class):
+        """
+        Test thread-safety under concurrent maintenance event handling.
+        Simulates multiple threads triggering MOVING events and performing operations concurrently.
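+        Each worker issues a SET on its own key_receive_moving_<idx> key, so
+        every thread drives a MOVING push through the shared pool handler.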
+ """ + import threading + + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + results = [] + errors = [] + + def worker(idx): + try: + key = f"key_receive_moving_{idx}" + value = f"value3_{idx}" + result = test_redis_client.set(key, value) + results.append(result) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + assert all(results), f"Not all threads succeeded: {results}" + assert not errors, f"Errors occurred in threads: {errors}" + # After all threads, MOVING event should have been handled safely + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): + """ + Test moving configs are not lost if the per connection events get picked up after moving is handled. + MOVING → MIGRATING → MIGRATED → MOVED + Checks the state after each event for all connections and for new connections created during each state. + """ + # Setup + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + pool = test_redis_client.connection_pool + pool_handler = pool.connection_kwargs["maintenance_events_pool_handler"] + + # Create and release some connections + in_use_connections = [] + for _ in range(3): + in_use_connections.append(pool.get_connection()) + while len(in_use_connections) > 0: + pool.release(in_use_connections.pop()) + + # Take 2 connections to be in use + in_use_connections = [] + for _ in range(2): + conn = pool.get_connection() + in_use_connections.append(conn) + + # 1. MOVING event + tmp_address = "22.23.24.25" + moving_event = NodeMovingEvent( + id=1, new_node_host=tmp_address, new_node_port=6379, ttl=1 + ) + pool_handler.handle_event(moving_event) + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + Helpers.validate_free_connections_state( + pool=pool, + should_be_connected_count=0, + connected_to_tmp_address=False, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + + # 2. 
MIGRATING event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeMigratingEvent(id=2, ttl=1) + ) + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + + # 3. MIGRATED event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeMigratedEvent(id=2) + ) + # State should not change for connections that are in MOVING state + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + + # 4. MOVED event (simulate timer expiry) + pool_handler.handle_node_moved_event(moving_event) + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.NONE, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + Helpers.validate_free_connections_state( + pool=pool, + should_be_connected_count=0, + connected_to_tmp_address=False, + expected_state=MaintenanceState.NONE, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + # New connection after MOVED + new_conn_none = pool.get_connection() + assert new_conn_none.maintenance_state == MaintenanceState.NONE + pool.release(new_conn_none) + # Cleanup + for conn in in_use_connections: + pool.release(conn) + if hasattr(pool, "disconnect"): + pool.disconnect() + + +class TestMaintenanceEventsHandlingMultipleProxies: + """Integration tests for maintenance events handling with real connection pool.""" + + def setup_method(self): + """Set up test fixtures with mocked sockets.""" + self.mock_sockets = [] + self.original_socket = socket.socket + self.orig_host = "test.address.com" + + # Mock socket creation to return our mock sockets + def mock_socket_factory(*args, **kwargs): + mock_sock = MockSocket() + self.mock_sockets.append(mock_sock) + return mock_sock + + self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory) + self.socket_patcher.start() + + # Mock select.select to simulate data availability for reading + def mock_select(rlist, wlist, xlist, timeout=0): + # Check if any of 
the sockets in rlist have data available
+            ready_sockets = []
+            for sock in rlist:
+                if hasattr(sock, "connected") and sock.connected and not sock.closed:
+                    # Only return socket as ready if it actually has data to read
+                    if hasattr(sock, "pending_responses") and sock.pending_responses:
+                        ready_sockets.append(sock)
+                    # Don't return socket as ready just because it received commands
+                    # Only when there are actual responses available
+            return (ready_sockets, [], [])
+
+        self.select_patcher = patch("select.select", side_effect=mock_select)
+        self.select_patcher.start()
+
+        ips = ["1.2.3.4", "5.6.7.8", "9.10.11.12"]
+        ips = ips * 3
+
+        # Mock getaddrinfo so the original host resolves to rotating proxy IPs
+        def mock_socket_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
+            if host == self.orig_host:
+                ip_address = ips.pop(0)
+            else:
+                ip_address = host
+
+            # Return the standard getaddrinfo format
+            # (family, type, proto, canonname, sockaddr)
+            return [
+                (
+                    socket.AF_INET,
+                    socket.SOCK_STREAM,
+                    socket.IPPROTO_TCP,
+                    "",
+                    (ip_address, port),
+                )
+            ]
+
+        self.getaddrinfo_patcher = patch(
+            "socket.getaddrinfo", side_effect=mock_socket_getaddrinfo
+        )
+        self.getaddrinfo_patcher.start()
+
+        # Create maintenance events config
+        self.config = MaintenanceEventsConfig(
+            enabled=True, proactive_reconnect=True, relax_timeout=30
+        )
+
+    def teardown_method(self):
+        """Clean up test fixtures."""
+        self.socket_patcher.stop()
+        self.select_patcher.stop()
+        self.getaddrinfo_patcher.stop()
+
+    @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool])
+    def test_migrating_after_moving_multiple_proxies(self, pool_class):
+        """
+        Test MOVING, MIGRATING and MIGRATED handling when the pool's
+        connections are spread across multiple proxy addresses.
+
+        MOVING events are sent per proxy, so only the connections behind the
+        affected proxy should change state, and MIGRATING/MIGRATED events must
+        not downgrade a connection that is already in MOVING state.
+        """
+        # Setup
+
+        pool = pool_class(
+            host=self.orig_host,
+            port=12345,
+            max_connections=10,
+            protocol=3,  # Required for maintenance events
+            maintenance_events_config=self.config,
+        )
+        pool.set_maintenance_events_pool_handler(
+            MaintenanceEventPoolHandler(pool, self.config)
+        )
+        pool_handler = pool.connection_kwargs["maintenance_events_pool_handler"]
+
+        # Create and release some connections
+        key1 = "1.2.3.4"
+        key2 = "5.6.7.8"
+        key3 = "9.10.11.12"
+        in_use_connections = {key1: [], key2: [], key3: []}
+        # Create 7 connections
+        for _ in range(7):
+            conn = pool.get_connection()
+            in_use_connections[conn.getpeername()].append(conn)
+
+        for _, conns in in_use_connections.items():
+            while len(conns) > 1:
+                pool.release(conns.pop())
+
+        # Send MOVING event to the connection with ip = key1
+        conn = in_use_connections[key1][0]
+        pool_handler.set_connection(conn)
+        new_ip = "13.14.15.16"
+        pool_handler.handle_event(
+            NodeMovingEvent(id=1, new_node_host=new_ip, new_node_port=6379, ttl=1)
+        )
+
+        # validate in use connection for ip1
+        Helpers.validate_in_use_connections_state(
+            in_use_connections[key1],
+            expected_state=MaintenanceState.MOVING,
+            expected_host_address=new_ip,
+            expected_socket_timeout=self.config.relax_timeout,
+            expected_socket_connect_timeout=self.config.relax_timeout,
+            expected_orig_host_address=self.orig_host,
+            expected_orig_socket_timeout=None,
+            expected_orig_socket_connect_timeout=None,
+            expected_current_socket_timeout=self.config.relax_timeout,
+            expected_current_peername=key1,
+        )
+        # validate free connections for ip1
+        changed_free_connections = 0
+        if isinstance(pool, BlockingConnectionPool):
+            free_connections = [conn for conn in pool.pool.queue if conn is not None]
+        elif isinstance(pool, ConnectionPool):
+            free_connections = pool._available_connections
+        for conn in free_connections:
+            if conn.host == new_ip:
+                changed_free_connections += 1
+                
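# Free connections re-pointed by the MOVING handler should carry
+                # the relaxed timeouts and the temporary address
+                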
assert conn.maintenance_state == MaintenanceState.MOVING
+                assert conn.host == new_ip
+                assert conn.socket_timeout == self.config.relax_timeout
+                assert conn.socket_connect_timeout == self.config.relax_timeout
+                assert conn.orig_host_address == self.orig_host
+                assert conn.orig_socket_timeout is None
+                assert conn.orig_socket_connect_timeout is None
+            else:
+                assert conn.maintenance_state == MaintenanceState.NONE
+                assert conn.host == self.orig_host
+                assert conn.socket_timeout is None
+                assert conn.socket_connect_timeout is None
+                assert conn.orig_host_address == self.orig_host
+                assert conn.orig_socket_timeout is None
+                assert conn.orig_socket_connect_timeout is None
+        assert changed_free_connections == 2
+        assert len(free_connections) == 4
+
+        # Send second MOVING event to the connection with ip = key2
+        conn = in_use_connections[key2][0]
+        pool_handler.set_connection(conn)
+        new_ip_2 = "17.18.19.20"
+        pool_handler.handle_event(
+            NodeMovingEvent(id=2, new_node_host=new_ip_2, new_node_port=6379, ttl=2)
+        )
+
+        # validate in use connection for ip2
+        Helpers.validate_in_use_connections_state(
+            in_use_connections[key2],
+            expected_state=MaintenanceState.MOVING,
+            expected_host_address=new_ip_2,
+            expected_socket_timeout=self.config.relax_timeout,
+            expected_socket_connect_timeout=self.config.relax_timeout,
+            expected_orig_host_address=self.orig_host,
+            expected_orig_socket_timeout=None,
+            expected_orig_socket_connect_timeout=None,
+            expected_current_socket_timeout=self.config.relax_timeout,
+            expected_current_peername=key2,
+        )
+        # validate free connections for ip2
+        changed_free_connections = 0
+        if isinstance(pool, BlockingConnectionPool):
+            free_connections = [conn for conn in pool.pool.queue if conn is not None]
+        elif isinstance(pool, ConnectionPool):
+            free_connections = pool._available_connections
+        for conn in free_connections:
+            if conn.host == new_ip_2:
+                changed_free_connections += 1
+                assert conn.maintenance_state == MaintenanceState.MOVING
+                assert conn.host == new_ip_2
+                assert conn.socket_timeout == self.config.relax_timeout
+                assert conn.socket_connect_timeout == self.config.relax_timeout
+                assert conn.orig_host_address == self.orig_host
+                assert conn.orig_socket_timeout is None
+                assert conn.orig_socket_connect_timeout is None
+        # The remaining free connections are not validated individually here:
+        # some are still in MOVING state from the first event and others are
+        # in NONE state
+        assert changed_free_connections == 1
+
+        # MIGRATING event on connection that has already been marked as MOVING
+        conn = in_use_connections[key2][0]
+        conn_event_handler = conn._maintenance_event_connection_handler
+        conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1))
+        # validate connection does not lose its MOVING state
+        assert conn.maintenance_state == MaintenanceState.MOVING
+        # MIGRATED event
+        conn_event_handler.handle_event(NodeMigratedEvent(id=3))
+        # validate connection does not lose its MOVING state or relax timeout
+        assert conn.maintenance_state == MaintenanceState.MOVING
+        assert conn.socket_timeout == self.config.relax_timeout
+
+        # Send MIGRATING event to the connection with ip = key3
+        conn = in_use_connections[key3][0]
+        conn_event_handler = conn._maintenance_event_connection_handler
+        conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1))
+        # validate connection is in MIGRATING state
+        assert conn.maintenance_state == MaintenanceState.MIGRATING
+        assert conn.socket_timeout == self.config.relax_timeout
+
+        # Send MIGRATED event to con with ip = key3
+        
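# This connection is only in MIGRATING state (not MOVING), so the
+        # connection handler is expected to fully reset it on MIGRATED
+        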
conn_event_handler.handle_event(NodeMigratedEvent(id=3))
+        # validate connection state is reset to NONE and the relax timeout is cleared
+        assert conn.maintenance_state == MaintenanceState.NONE
+        assert conn.socket_timeout is None
+
+        # sleep long enough to expire only the first MOVING event
+        sleep(1.3)
+        # validate only the connections affected by the first MOVING event
+        # have lost their MOVING state
+        Helpers.validate_in_use_connections_state(
+            in_use_connections[key1],
+            expected_state=MaintenanceState.NONE,
+            expected_host_address=self.orig_host,
+            expected_socket_timeout=None,
+            expected_socket_connect_timeout=None,
+            expected_orig_host_address=self.orig_host,
+            expected_orig_socket_timeout=None,
+            expected_orig_socket_connect_timeout=None,
+            expected_current_socket_timeout=None,
+            expected_current_peername=key1,
+        )
+        Helpers.validate_in_use_connections_state(
+            in_use_connections[key2],
+            expected_state=MaintenanceState.MOVING,
+            expected_host_address=new_ip_2,
+            expected_socket_timeout=self.config.relax_timeout,
+            expected_socket_connect_timeout=self.config.relax_timeout,
+            expected_orig_host_address=self.orig_host,
+            expected_orig_socket_timeout=None,
+            expected_orig_socket_connect_timeout=None,
+            expected_current_socket_timeout=self.config.relax_timeout,
+            expected_current_peername=key2,
+        )
+        Helpers.validate_in_use_connections_state(
+            in_use_connections[key3],
+            expected_state=MaintenanceState.NONE,
+            expected_should_reconnect=False,
+            expected_host_address=self.orig_host,
+            expected_socket_timeout=None,
+            expected_socket_connect_timeout=None,
+            expected_orig_host_address=self.orig_host,
+            expected_orig_socket_timeout=None,
+            expected_orig_socket_connect_timeout=None,
+            expected_current_socket_timeout=None,
+            expected_current_peername=key3,
+        )
+        # TODO validate free connections
+
+        # sleep to expire the second MOVING event
+        sleep(1)
+        # validate all connections have lost their MOVING state
+        Helpers.validate_in_use_connections_state(
+            [
+                *in_use_connections[key1],
+                *in_use_connections[key2],
+                *in_use_connections[key3],
+            ],
+            expected_state=MaintenanceState.NONE,
+            expected_should_reconnect="any",
+            expected_host_address=self.orig_host,
+            expected_socket_timeout=None,
+            expected_socket_connect_timeout=None,
+            expected_orig_host_address=self.orig_host,
+            expected_orig_socket_timeout=None,
+            expected_orig_socket_connect_timeout=None,
+            expected_current_socket_timeout=None,
+            expected_current_peername="any",
+        )
+        # TODO validate free connections

From 6f139535898900038eddf1385d7180d08a741818 Mon Sep 17 00:00:00 2001
From: petyaslavova
Date: Tue, 19 Aug 2025 12:08:12 +0300
Subject: [PATCH 2/4] Adding handling of FAILING_OVER and FAILED_OVER
 events/push notifications (#3716)

---
 redis/maintenance_events.py               | 133 ++++++++++++++++--
 tests/test_maintenance_events.py          | 162 +++++++++++++++++++---
 tests/test_maintenance_events_handling.py | 127 ++++++++++++++++-
 3 files changed, 385 insertions(+), 37 deletions(-)

diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py
index 8c6c15c74a..09c767be63 100644
--- a/redis/maintenance_events.py
+++ b/redis/maintenance_events.py
@@ -11,7 +11,7 @@ class MaintenanceState(enum.Enum):
     NONE = "none"
     MOVING = "moving"
-    MIGRATING = "migrating"
+    MAINTENANCE = "maintenance"


 if TYPE_CHECKING:
@@ -261,6 +261,105 @@ def __hash__(self) -> int:
         return hash((self.__class__, self.id))


+class NodeFailingOverEvent(MaintenanceEvent):
+    """
+    Event for when a Redis cluster node is in the process of failing over.
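+    Corresponds to the FAILING_OVER push notification sent by the server.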
+ + This event is received when a node starts a failover process during + cluster maintenance operations or when handling node failures. + + Args: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + """ + + def __init__(self, id: int, ttl: int): + super().__init__(id, ttl) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeFailingOverEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeFailingOverEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + +class NodeFailedOverEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node has completed a failover. + + This event is received when a node has finished the failover process + during cluster maintenance operations or after handling node failures. + + Args: + id (int): Unique identifier for this event + """ + + DEFAULT_TTL = 5 + + def __init__(self, id: int): + super().__init__(id, NodeFailedOverEvent.DEFAULT_TTL) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeFailedOverEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeFailedOverEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + class MaintenanceEventsConfig: """ Configuration class for maintenance events handling behaviour. 
Events are received through @@ -457,6 +556,14 @@ def handle_node_moved_event(self, event: NodeMovingEvent): class MaintenanceEventConnectionHandler: + # 1 = "starting maintenance" events, 0 = "completed maintenance" events + _EVENT_TYPES: dict[type["MaintenanceEvent"], int] = { + NodeMigratingEvent: 1, + NodeFailingOverEvent: 1, + NodeMigratedEvent: 0, + NodeFailedOverEvent: 0, + } + def __init__( self, connection: "ConnectionInterface", config: MaintenanceEventsConfig ) -> None: @@ -464,25 +571,31 @@ def __init__( self.config = config def handle_event(self, event: MaintenanceEvent): - if isinstance(event, NodeMigratingEvent): - return self.handle_migrating_event(event) - elif isinstance(event, NodeMigratedEvent): - return self.handle_migration_completed_event(event) - else: + # get the event type by checking its class in the _EVENT_TYPES dict + event_type = self._EVENT_TYPES.get(event.__class__, None) + + if event_type is None: logging.error(f"Unhandled event type: {event}") + return - def handle_migrating_event(self, notification: NodeMigratingEvent): + if event_type: + self.handle_maintenance_start_event(MaintenanceState.MAINTENANCE) + else: + self.handle_maintenance_completed_event() + + def handle_maintenance_start_event(self, maintenance_state: MaintenanceState): if ( self.connection.maintenance_state == MaintenanceState.MOVING or not self.config.is_relax_timeouts_enabled() ): return - self.connection.maintenance_state = MaintenanceState.MIGRATING + + self.connection.maintenance_state = maintenance_state self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) - def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): + def handle_maintenance_completed_event(self): # Only reset timeouts if state is not MOVING and relax timeouts are enabled if ( self.connection.maintenance_state == MaintenanceState.MOVING @@ -490,7 +603,7 @@ def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): ): return self.connection.reset_tmp_settings(reset_relax_timeout=True) - # Node migration completed - reset the connection + # Maintenance completed - reset the connection # timeouts by providing -1 as the relax timeout self.connection.update_current_socket_timeout(-1) self.connection.maintenance_state = MaintenanceState.NONE diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index c90fa5db4f..30169615cf 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -7,9 +7,12 @@ NodeMovingEvent, NodeMigratingEvent, NodeMigratedEvent, + NodeFailingOverEvent, + NodeFailedOverEvent, MaintenanceEventsConfig, MaintenanceEventPoolHandler, MaintenanceEventConnectionHandler, + MaintenanceState, ) @@ -281,6 +284,84 @@ def test_equality_and_hash(self): assert hash(event1) != hash(event3) +class TestNodeFailingOverEvent: + """Test the NodeFailingOverEvent class.""" + + def test_init(self): + """Test NodeFailingOverEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailingOverEvent(id=1, ttl=5) + assert event.id == 1 + assert event.ttl == 5 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeFailingOverEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailingOverEvent(id=1, ttl=5) + + with patch("time.monotonic", return_value=1002): # 2 seconds later + repr_str = 
repr(event) + assert "NodeFailingOverEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=3.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeFailingOverEvent.""" + event1 = NodeFailingOverEvent(id=1, ttl=5) + event2 = NodeFailingOverEvent(id=1, ttl=10) # Same id, different ttl + event3 = NodeFailingOverEvent(id=2, ttl=5) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestNodeFailedOverEvent: + """Test the NodeFailedOverEvent class.""" + + def test_init(self): + """Test NodeFailedOverEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailedOverEvent(id=1) + assert event.id == 1 + assert event.ttl == NodeFailedOverEvent.DEFAULT_TTL + assert event.creation_time == 1000 + + def test_default_ttl(self): + """Test that DEFAULT_TTL is used correctly.""" + assert NodeFailedOverEvent.DEFAULT_TTL == 5 + event = NodeFailedOverEvent(id=1) + assert event.ttl == 5 + + def test_repr(self): + """Test NodeFailedOverEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailedOverEvent(id=1) + + with patch("time.monotonic", return_value=1001): # 1 second later + repr_str = repr(event) + assert "NodeFailedOverEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=4.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeFailedOverEvent.""" + event1 = NodeFailedOverEvent(id=1) + event2 = NodeFailedOverEvent(id=1) # Same id + event3 = NodeFailedOverEvent(id=2) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + class TestMaintenanceEventsConfig: """Test the MaintenanceEventsConfig class.""" @@ -477,19 +558,41 @@ def test_handle_event_migrating(self): """Test handling of NodeMigratingEvent.""" event = NodeMigratingEvent(id=1, ttl=5) - with patch.object(self.handler, "handle_migrating_event") as mock_handle: + with patch.object( + self.handler, "handle_maintenance_start_event" + ) as mock_handle: self.handler.handle_event(event) - mock_handle.assert_called_once_with(event) + mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE) def test_handle_event_migrated(self): """Test handling of NodeMigratedEvent.""" event = NodeMigratedEvent(id=1) with patch.object( - self.handler, "handle_migration_completed_event" + self.handler, "handle_maintenance_completed_event" ) as mock_handle: self.handler.handle_event(event) - mock_handle.assert_called_once_with(event) + mock_handle.assert_called_once_with() + + def test_handle_event_failing_over(self): + """Test handling of NodeFailingOverEvent.""" + event = NodeFailingOverEvent(id=1, ttl=5) + + with patch.object( + self.handler, "handle_maintenance_start_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE) + + def test_handle_event_failed_over(self): + """Test handling of NodeFailedOverEvent.""" + event = NodeFailedOverEvent(id=1) + + with patch.object( + self.handler, "handle_maintenance_completed_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with() def test_handle_event_unknown_type(self): """Test handling of unknown event 
type.""" @@ -500,42 +603,61 @@ def test_handle_event_unknown_type(self): result = self.handler.handle_event(event) assert result is None - def test_handle_migrating_event_disabled(self): - """Test migrating event handling when relax timeouts are disabled.""" + def test_handle_maintenance_start_event_disabled(self): + """Test maintenance start event handling when relax timeouts are disabled.""" config = MaintenanceEventsConfig(relax_timeout=-1) handler = MaintenanceEventConnectionHandler(self.mock_connection, config) - event = NodeMigratingEvent(id=1, ttl=5) - result = handler.handle_migrating_event(event) + result = handler.handle_maintenance_start_event(MaintenanceState.MAINTENANCE) assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() - def test_handle_migrating_event_success(self): - """Test successful migrating event handling.""" - event = NodeMigratingEvent(id=1, ttl=5) + def test_handle_maintenance_start_event_moving_state(self): + """Test maintenance start event handling when connection is in MOVING state.""" + self.mock_connection.maintenance_state = MaintenanceState.MOVING + + result = self.handler.handle_maintenance_start_event( + MaintenanceState.MAINTENANCE + ) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_maintenance_start_event_success(self): + """Test successful maintenance start event handling for migrating.""" + self.mock_connection.maintenance_state = MaintenanceState.NONE - self.handler.handle_migrating_event(event) + self.handler.handle_maintenance_start_event(MaintenanceState.MAINTENANCE) + assert self.mock_connection.maintenance_state == MaintenanceState.MAINTENANCE self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) self.mock_connection.set_tmp_settings.assert_called_once_with( tmp_relax_timeout=20 ) - def test_handle_migration_completed_event_disabled(self): - """Test migration completed event handling when relax timeouts are disabled.""" + def test_handle_maintenance_completed_event_disabled(self): + """Test maintenance completed event handling when relax timeouts are disabled.""" config = MaintenanceEventsConfig(relax_timeout=-1) handler = MaintenanceEventConnectionHandler(self.mock_connection, config) - event = NodeMigratedEvent(id=1) - result = handler.handle_migration_completed_event(event) + result = handler.handle_maintenance_completed_event() assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() - def test_handle_migration_completed_event_success(self): - """Test successful migration completed event handling.""" - event = NodeMigratedEvent(id=1) + def test_handle_maintenance_completed_event_moving_state(self): + """Test maintenance completed event handling when connection is in MOVING state.""" + self.mock_connection.maintenance_state = MaintenanceState.MOVING + + result = self.handler.handle_maintenance_completed_event() + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_maintenance_completed_event_success(self): + """Test successful maintenance completed event handling.""" + self.mock_connection.maintenance_state = MaintenanceState.MAINTENANCE + + self.handler.handle_maintenance_completed_event() - self.handler.handle_migration_completed_event(event) + assert self.mock_connection.maintenance_state == MaintenanceState.NONE self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) 
self.mock_connection.reset_tmp_settings.assert_called_once_with( diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 8db8d182a7..ea0021c8a5 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -16,9 +16,11 @@ from redis.maintenance_events import ( MaintenanceEventsConfig, NodeMigratingEvent, + NodeMigratedEvent, + NodeFailingOverEvent, + NodeFailedOverEvent, MaintenanceEventPoolHandler, NodeMovingEvent, - NodeMigratedEvent, ) @@ -189,6 +191,22 @@ def send(self, data): # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" response = migrated_push.encode() + response + elif ( + b"key_receive_failing_over_" in data + or b"key_receive_failing_over" in data + ): + # FAILING_OVER push message before SET key_receive_failing_over_X response + # Format: >2\r\n$12\r\nFAILING_OVER\r\n:10\r\n (2 elements: FAILING_OVER, ttl) + failing_over_push = ">2\r\n$12\r\nFAILING_OVER\r\n:10\r\n" + response = failing_over_push.encode() + response + elif ( + b"key_receive_failed_over_" in data + or b"key_receive_failed_over" in data + ): + # FAILED_OVER push message before SET key_receive_failed_over_X response + # Format: >1\r\n$11\r\nFAILED_OVER\r\n (1 element: FAILED_OVER) + failed_over_push = ">1\r\n$11\r\nFAILED_OVER\r\n" + response = failed_over_push.encode() + response elif b"key_receive_moving_" in data: # MOVING push message before SET key_receive_moving_X response # Format: >3\r\n$6\r\nMOVING\r\n:15\r\n+localhost:6379\r\n (3 elements: MOVING, ttl, host:port) @@ -211,6 +229,10 @@ def send(self, data): self.pending_responses.append(b"$6\r\nvalue2\r\n") elif b"key_receive_migrated" in data: self.pending_responses.append(b"$6\r\nvalue3\r\n") + elif b"key_receive_failing_over" in data: + self.pending_responses.append(b"$6\r\nvalue4\r\n") + elif b"key_receive_failed_over" in data: + self.pending_responses.append(b"$6\r\nvalue5\r\n") elif b"key1" in data: self.pending_responses.append(b"$6\r\nvalue1\r\n") else: @@ -727,13 +749,14 @@ def test_migration_related_events_handling_integration(self, pool_class): @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_migrating_event_with_disabled_relax_timeout(self, pool_class): """ - Test migrating event handling when relax timeout is disabled. + Test maintenance events handling when relax timeout is disabled. This test validates that when relax_timeout is disabled (-1): - 1. MIGRATING events are received and processed + 1. MIGRATING, MIGRATED, FAILING_OVER, and FAILED_OVER events are received and processed 2. No timeout updates are applied to connections - 3. Socket timeouts remain unchanged during migration events + 3. Socket timeouts remain unchanged during all maintenance events 4. Tests both ConnectionPool and BlockingConnectionPool implementations + 5. Tests the complete lifecycle: MIGRATING -> MIGRATED -> FAILING_OVER -> FAILED_OVER """ # Create config with disabled relax timeout disabled_config = MaintenanceEventsConfig( @@ -776,6 +799,57 @@ def test_migrating_event_with_disabled_relax_timeout(self, pool_class): f"Command 3 (GET key1) failed. 
Expected: {expected_value3}, Got: {result3}" ) + # Command 4: This SET command will receive MIGRATED push message before response + key_migrated = "key_receive_migrated" + value_migrated = "value3" + result4 = test_redis_client.set(key_migrated, value_migrated) + + # Validate Command 4 result + assert result4 is True, "Command 4 (SET key_receive_migrated) failed" + + # Validate timeout is still NOT updated after MIGRATED (relax is disabled) + self._validate_current_timeout(None) + + # Command 5: This SET command will receive FAILING_OVER push message before response + key_failing_over = "key_receive_failing_over" + value_failing_over = "value4" + result5 = test_redis_client.set(key_failing_over, value_failing_over) + + # Validate Command 5 result + assert result5 is True, "Command 5 (SET key_receive_failing_over) failed" + + # Validate timeout is still NOT updated after FAILING_OVER (relax is disabled) + self._validate_current_timeout(None) + + # Command 6: Another command to verify timeout remains unchanged during failover + result6 = test_redis_client.get(key_failing_over) + + # Validate Command 6 result + expected_value6 = value_failing_over.encode() + assert result6 == expected_value6, ( + f"Command 6 (GET key_receive_failing_over) failed. Expected: {expected_value6}, Got: {result6}" + ) + + # Command 7: This SET command will receive FAILED_OVER push message before response + key_failed_over = "key_receive_failed_over" + value_failed_over = "value5" + result7 = test_redis_client.set(key_failed_over, value_failed_over) + + # Validate Command 7 result + assert result7 is True, "Command 7 (SET key_receive_failed_over) failed" + + # Validate timeout is still NOT updated after FAILED_OVER (relax is disabled) + self._validate_current_timeout(None) + + # Command 8: Final command to verify timeout remains unchanged after all events + result8 = test_redis_client.get(key_failed_over) + + # Validate Command 8 result + expected_value8 = value_failed_over.encode() + assert result8 == expected_value8, ( + f"Command 8 (GET key_receive_failed_over) failed. Expected: {expected_value8}, Got: {result8}" + ) + # Verify maintenance events were processed correctly # The key is that we have at least 1 socket and all operations succeeded assert len(self.mock_sockets) >= 1, ( @@ -1357,7 +1431,8 @@ def worker(idx): def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): """ Test moving configs are not lost if the per connection events get picked up after moving is handled. - MOVING → MIGRATING → MIGRATED → MOVED + Sequence of events: MOVING, MIGRATING, MIGRATED, FAILING_OVER, FAILED_OVER, MOVED. + Note: FAILING_OVER and FAILED_OVER events do not change the connection state when already in MOVING state. Checks the state after each event for all connections and for new connections created during each state. """ # Setup @@ -1448,7 +1523,45 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) - # 4. MOVED event (simulate timer expiry) + # 4. 
FAILING_OVER event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeFailingOverEvent(id=3, ttl=1) + ) + # State should not change for connections that are in MOVING state + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + + # 5. FAILED_OVER event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeFailedOverEvent(id=3) + ) + # State should not change for connections that are in MOVING state + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + + # 6. MOVED event (simulate timer expiry) pool_handler.handle_node_moved_event(moving_event) Helpers.validate_in_use_connections_state( in_use_connections, @@ -1695,7 +1808,7 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class): conn_event_handler = conn._maintenance_event_connection_handler conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1)) # validate connection is in MIGRATING state - assert conn.maintenance_state == MaintenanceState.MIGRATING + assert conn.maintenance_state == MaintenanceState.MAINTENANCE assert conn.socket_timeout == self.config.relax_timeout # Send MIGRATED event to con with ip = key3 From f2c677a85f20fb863cdd2e8674fa22dbd2bf8a52 Mon Sep 17 00:00:00 2001 From: Elena Kolevska Date: Wed, 20 Aug 2025 07:58:05 +0100 Subject: [PATCH 3/4] Hitless upgrade: Adding handshake command to enable the notifications after connection is established (#3735) --- redis/_parsers/base.py | 119 ++++++++++------ redis/connection.py | 164 +++++++++++++++++---- redis/maintenance_events.py | 120 +++++++++++++++- tests/test_maintenance_events.py | 166 +++++++++++++++++++++- tests/test_maintenance_events_handling.py | 22 +-- 5 files changed, 509 insertions(+), 82 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index dd2d8b9de0..65472ea3f7 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -1,3 +1,4 @@ +import logging import sys from abc import ABC from asyncio import IncompleteReadError, StreamReader, TimeoutError @@ -56,6 +57,8 @@ "Client sent AUTH, but no password is set": AuthenticationError, } +logger = logging.getLogger(__name__) + class BaseParser(ABC): EXCEPTION_CLASSES = { @@ -199,31 +202,42 @@ def handle_push_response(self, response, **kwargs): *_MOVING_MESSAGE, ): return self.pubsub_push_handler_func(response) - if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: - return self.invalidation_push_handler_func(response) - if msg_type in 
_MOVING_MESSAGE and self.node_moving_push_handler_func: - # TODO: PARSE latest format when available - host, port = response[2].decode().split(":") - ttl = response[1] - id = 1 # Hardcoded value until the notification starts including the id - notification = NodeMovingEvent(id, host, port, ttl) - return self.node_moving_push_handler_func(notification) - if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: - if msg_type in _MIGRATING_MESSAGE: - # TODO: PARSE latest format when available - ttl = response[1] - id = 2 # Hardcoded value until the notification starts including the id - notification = NodeMigratingEvent(id, ttl) - elif msg_type in _MIGRATED_MESSAGE: - # TODO: PARSE latest format when available - id = 3 # Hardcoded value until the notification starts including the id - notification = NodeMigratedEvent(id) - else: + + try: + if ( + msg_type in _INVALIDATION_MESSAGE + and self.invalidation_push_handler_func + ): + return self.invalidation_push_handler_func(response) + + if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + # Expected message format is: MOVING