@@ -293,6 +293,9 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None:
293293
294294 async def connect (self ):
295295 """Connects to the Redis server if not already connected"""
296+ await self .connect_check_health (check_health = True )
297+
298+ async def connect_check_health (self , check_health : bool = True ):
296299 if self .is_connected :
297300 return
298301 try :
@@ -311,7 +314,7 @@ async def connect(self):
311314 try :
312315 if not self .redis_connect_func :
313316 # Use the default on_connect function
314- await self .on_connect ( )
317+ await self .on_connect_check_health ( check_health = check_health )
315318 else :
316319 # Use the passed function redis_connect_func
317320 (
@@ -350,6 +353,9 @@ def get_protocol(self):
350353
351354 async def on_connect (self ) -> None :
352355 """Initialize the connection, authenticate and select a database"""
356+ await self .on_connect_check_health (check_health = True )
357+
358+ async def on_connect_check_health (self , check_health : bool = True ) -> None :
353359 self ._parser .on_connect (self )
354360 parser = self ._parser
355361
@@ -407,7 +413,7 @@ async def on_connect(self) -> None:
407413 # update cluster exception classes
408414 self ._parser .EXCEPTION_CLASSES = parser .EXCEPTION_CLASSES
409415 self ._parser .on_connect (self )
410- await self .send_command ("HELLO" , self .protocol )
416+ await self .send_command ("HELLO" , self .protocol , check_health = check_health )
411417 response = await self .read_response ()
412418 # if response.get(b"proto") != self.protocol and response.get(
413419 # "proto"
@@ -416,18 +422,35 @@ async def on_connect(self) -> None:
416422
417423 # if a client_name is given, set it
418424 if self .client_name :
419- await self .send_command ("CLIENT" , "SETNAME" , self .client_name )
425+ await self .send_command (
426+ "CLIENT" ,
427+ "SETNAME" ,
428+ self .client_name ,
429+ check_health = check_health ,
430+ )
420431 if str_if_bytes (await self .read_response ()) != "OK" :
421432 raise ConnectionError ("Error setting client name" )
422433
423434 # set the library name and version, pipeline for lower startup latency
424435 if self .lib_name :
425- await self .send_command ("CLIENT" , "SETINFO" , "LIB-NAME" , self .lib_name )
436+ await self .send_command (
437+ "CLIENT" ,
438+ "SETINFO" ,
439+ "LIB-NAME" ,
440+ self .lib_name ,
441+ check_health = check_health ,
442+ )
426443 if self .lib_version :
427- await self .send_command ("CLIENT" , "SETINFO" , "LIB-VER" , self .lib_version )
444+ await self .send_command (
445+ "CLIENT" ,
446+ "SETINFO" ,
447+ "LIB-VER" ,
448+ self .lib_version ,
449+ check_health = check_health ,
450+ )
428451 # if a database is specified, switch to it. Also pipeline this
429452 if self .db :
430- await self .send_command ("SELECT" , self .db )
453+ await self .send_command ("SELECT" , self .db , check_health = check_health )
431454
432455 # read responses from pipeline
433456 for _ in (sent for sent in (self .lib_name , self .lib_version ) if sent ):
@@ -489,8 +512,8 @@ async def send_packed_command(
489512 self , command : Union [bytes , str , Iterable [bytes ]], check_health : bool = True
490513 ) -> None :
491514 if not self .is_connected :
492- await self .connect ( )
493- elif check_health :
515+ await self .connect_check_health ( check_health = False )
516+ if check_health :
494517 await self .check_health ()
495518
496519 try :
0 commit comments