@@ -282,6 +282,13 @@ def should_reconnect(self):
282282 """
283283 pass
284284
285+ @abstractmethod
286+ def get_resolved_ip (self ):
287+ """
288+ Get resolved ip address for the connection.
289+ """
290+ pass
291+
285292 @abstractmethod
286293 def update_current_socket_timeout (self , relax_timeout : Optional [float ] = None ):
287294 """
@@ -421,32 +428,16 @@ def __init__(
421428 parser_class = _RESP3Parser
422429 self .set_parser (parser_class )
423430
424- if maintenance_events_config and maintenance_events_config .enabled :
425- if maintenance_events_pool_handler :
426- maintenance_events_pool_handler .set_connection (self )
427- self ._parser .set_node_moving_push_handler (
428- maintenance_events_pool_handler .handle_event
429- )
430- self ._maintenance_event_connection_handler = (
431- MaintenanceEventConnectionHandler (self , maintenance_events_config )
432- )
433- self ._parser .set_maintenance_push_handler (
434- self ._maintenance_event_connection_handler .handle_event
435- )
431+ self .maintenance_events_config = maintenance_events_config
432+
433+ # Set up maintenance events if enabled
434+ self ._configure_maintenance_events (
435+ maintenance_events_pool_handler ,
436+ orig_host_address ,
437+ orig_socket_timeout ,
438+ orig_socket_connect_timeout ,
439+ )
436440
437- self .orig_host_address = (
438- orig_host_address if orig_host_address else self .host
439- )
440- self .orig_socket_timeout = (
441- orig_socket_timeout if orig_socket_timeout else self .socket_timeout
442- )
443- self .orig_socket_connect_timeout = (
444- orig_socket_connect_timeout
445- if orig_socket_connect_timeout
446- else self .socket_connect_timeout
447- )
448- else :
449- self ._maintenance_event_connection_handler = None
450441 self ._should_reconnect = False
451442 self .maintenance_state = maintenance_state
452443
@@ -505,6 +496,46 @@ def set_parser(self, parser_class):
505496 """
506497 self ._parser = parser_class (socket_read_size = self ._socket_read_size )
507498
499+ def _configure_maintenance_events (
500+ self ,
501+ maintenance_events_pool_handler = None ,
502+ orig_host_address = None ,
503+ orig_socket_timeout = None ,
504+ orig_socket_connect_timeout = None ,
505+ ):
506+ """Enable maintenance events by setting up handlers and storing original connection parameters."""
507+ if (
508+ not self .maintenance_events_config
509+ or not self .maintenance_events_config .enabled
510+ ):
511+ self ._maintenance_event_connection_handler = None
512+ return
513+
514+ # Set up pool handler if available
515+ if maintenance_events_pool_handler :
516+ self ._parser .set_node_moving_push_handler (
517+ maintenance_events_pool_handler .handle_event
518+ )
519+
520+ # Set up connection handler
521+ self ._maintenance_event_connection_handler = MaintenanceEventConnectionHandler (
522+ self , self .maintenance_events_config
523+ )
524+ self ._parser .set_maintenance_push_handler (
525+ self ._maintenance_event_connection_handler .handle_event
526+ )
527+
528+ # Store original connection parameters
529+ self .orig_host_address = orig_host_address if orig_host_address else self .host
530+ self .orig_socket_timeout = (
531+ orig_socket_timeout if orig_socket_timeout else self .socket_timeout
532+ )
533+ self .orig_socket_connect_timeout = (
534+ orig_socket_connect_timeout
535+ if orig_socket_connect_timeout
536+ else self .socket_connect_timeout
537+ )
538+
508539 def set_maintenance_event_pool_handler (
509540 self , maintenance_event_pool_handler : MaintenanceEventPoolHandler
510541 ):
@@ -652,6 +683,39 @@ def on_connect_check_health(self, check_health: bool = True):
652683 ):
653684 raise ConnectionError ("Invalid RESP version" )
654685
686+ # Send maintenance notifications handshake if RESP3 is active and maintenance events are enabled
687+ # and we have a host to determine the endpoint type from
688+ if (
689+ self .protocol not in [2 , "2" ]
690+ and self .maintenance_events_config
691+ and self .maintenance_events_config .enabled
692+ and self ._maintenance_event_connection_handler
693+ and hasattr (self , "host" )
694+ ):
695+ try :
696+ endpoint_type = self .maintenance_events_config .get_endpoint_type (
697+ self .host , self
698+ )
699+ self .send_command (
700+ "CLIENT" ,
701+ "MAINT_NOTIFICATIONS" ,
702+ "ON" ,
703+ "moving-endpoint-type" ,
704+ endpoint_type .value ,
705+ check_health = check_health ,
706+ )
707+ response = self .read_response ()
708+ if str_if_bytes (response ) != "OK" :
709+ raise ConnectionError (
710+ "The server doesn't support maintenance notifications"
711+ )
712+ except Exception as e :
713+ # Log warning but don't fail the connection
714+ import logging
715+
716+ logger = logging .getLogger (__name__ )
717+ logger .warning (f"Failed to enable maintenance notifications: { e } " )
718+
655719 # if a client_name is given, set it
656720 if self .client_name :
657721 self .send_command (
@@ -888,6 +952,56 @@ def re_auth(self):
888952 self .read_response ()
889953 self ._re_auth_token = None
890954
955+ def get_resolved_ip (self ) -> Optional [str ]:
956+ """
957+ Extract the resolved IP address from an
958+ established connection or resolve it from the host.
959+
960+ First tries to get the actual IP from the socket (most accurate),
961+ then falls back to DNS resolution if needed.
962+
963+ Args:
964+ connection: The connection object to extract the IP from
965+
966+ Returns:
967+ str: The resolved IP address, or None if it cannot be determined
968+ """
969+
970+ # Method 1: Try to get the actual IP from the established socket connection
971+ # This is most accurate as it shows the exact IP being used
972+ try :
973+ if self ._sock is not None :
974+ peer_addr = self ._sock .getpeername ()
975+ if peer_addr and len (peer_addr ) >= 1 :
976+ # For TCP sockets, peer_addr is typically (host, port) tuple
977+ # Return just the host part
978+ return peer_addr [0 ]
979+ except (AttributeError , OSError ):
980+ # Socket might not be connected or getpeername() might fail
981+ pass
982+
983+ # Method 2: Fallback to DNS resolution of the host
984+ # This is less accurate but works when socket is not available
985+ try :
986+ host = getattr (self , "host" , "localhost" )
987+ port = getattr (self , "port" , 6379 )
988+ if host :
989+ # Use getaddrinfo to resolve the hostname to IP
990+ # This mimics what the connection would do during _connect()
991+ addr_info = socket .getaddrinfo (
992+ host , port , socket .AF_UNSPEC , socket .SOCK_STREAM
993+ )
994+ if addr_info :
995+ # Return the IP from the first result
996+ # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
997+ # sockaddr[0] is the IP address
998+ return addr_info [0 ][4 ][0 ]
999+ except (AttributeError , OSError , socket .gaierror ):
1000+ # DNS resolution might fail
1001+ pass
1002+
1003+ return None
1004+
8911005 @property
8921006 def maintenance_state (self ) -> MaintenanceState :
8931007 return self ._maintenance_state
0 commit comments